Skip to main content

qail_pg/driver/
explain.rs

//! EXPLAIN-based query cost estimation for pre-check rejection.
//!
//! Provides runtime cost-based rejection of queries that would be too
//! expensive, using PostgreSQL's `EXPLAIN (FORMAT JSON)` output.
//!
//! # Modes
//! - **Off**: No EXPLAIN pre-check
//! - **Precheck**: Run EXPLAIN on cache-miss for queries with expand depth ≥ threshold
//! - **Enforce**: Always run EXPLAIN and enforce cost thresholds
//!
//! # Caching
//! EXPLAIN results are cached by `AST_shape_hash + rls_signature` with configurable TTL.
//! This avoids repeated EXPLAIN calls for the same query shape.
14
15use std::collections::HashMap;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
/// Selects how aggressively the EXPLAIN pre-check is applied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExplainMode {
    /// Pre-check disabled — fastest, but offers no protection.
    Off,
    /// EXPLAIN runs on cache-miss for queries above the depth threshold.
    /// Recommended default for production.
    Precheck,
    /// EXPLAIN always runs and is enforced — strictest, slight latency cost.
    /// Recommended for staging or high-security tenants.
    Enforce,
}

/// Tunable settings for the EXPLAIN pre-check.
#[derive(Debug, Clone)]
pub struct ExplainConfig {
    /// Which pre-check mode is active.
    pub mode: ExplainMode,

    /// Queries whose expand_depth is >= this value get an EXPLAIN pre-check.
    /// Default: 3 (queries joining 3+ tables get pre-checked).
    pub depth_threshold: usize,

    /// Ceiling on PostgreSQL's estimated total cost; anything above is rejected.
    /// Default: 100,000 (unitless PostgreSQL planner cost).
    pub max_total_cost: f64,

    /// Ceiling on PostgreSQL's estimated row count; anything above is rejected.
    /// Default: 1,000,000 rows.
    pub max_plan_rows: u64,

    /// How long a cached EXPLAIN result stays valid.
    /// Default: 5 minutes.
    pub cache_ttl: Duration,
}

impl Default for ExplainConfig {
    /// Builds the documented defaults: `Precheck` mode, depth threshold 3,
    /// cost limit 100,000, row limit 1,000,000 and a 5-minute cache TTL.
    fn default() -> Self {
        // Five minutes, expressed in seconds.
        const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(300);

        ExplainConfig {
            mode: ExplainMode::Precheck,
            depth_threshold: 3,
            max_total_cost: 100_000.0,
            max_plan_rows: 1_000_000,
            cache_ttl: DEFAULT_CACHE_TTL,
        }
    }
}
66
/// Result of an EXPLAIN pre-check.
#[derive(Debug, Clone)]
pub struct ExplainEstimate {
    /// PostgreSQL's estimated total cost (arbitrary units).
    pub total_cost: f64,
    /// PostgreSQL's estimated number of rows returned.
    pub plan_rows: u64,
}

/// Cached EXPLAIN result with TTL and row-estimate snapshot.
struct CachedEstimate {
    /// The estimate handed back to callers on a cache hit.
    estimate: ExplainEstimate,
    /// When the entry was stored; compared against the cache TTL.
    cached_at: Instant,
    /// Row estimate snapshot at cache time, for drift detection.
    plan_rows: u64,
}

/// In-memory cache for EXPLAIN estimates, keyed by AST shape hash.
pub struct ExplainCache {
    /// Shape hash → cached estimate, guarded by a mutex for shared access.
    entries: Mutex<HashMap<u64, CachedEstimate>>,
    /// Entries older than this are treated as expired.
    ttl: Duration,
    /// Maximum number of cached entries to prevent OOM from shape explosion
    max_entries: usize,
}

impl ExplainCache {
    /// Relative change above which a cached row estimate counts as drifted.
    const DRIFT_RATIO: f64 = 0.5;
    /// Absolute row delta that must also be exceeded (prevents small-table thrash).
    const DRIFT_MIN_ROWS: u64 = 10_000;

    /// Create a new EXPLAIN cache with the given TTL and a 10,000-entry cap.
    pub fn new(ttl: Duration) -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            ttl,
            max_entries: 10_000,
        }
    }

    /// Get a cached estimate if it exists, hasn't expired, and its row
    /// estimate hasn't drifted.
    ///
    /// `current_reltuples` is the caller's current row count for the primary
    /// table (documented as `pg_class.reltuples`). When it is `Some` and the
    /// cached `plan_rows` is non-zero, the entry is treated as stale if BOTH
    /// the relative change exceeds 50% AND the absolute change exceeds
    /// 10,000 rows.
    ///
    /// Returns `None` on a miss, on expiry, on drift, or if the mutex is
    /// poisoned. Stale entries are not removed here; `insert` sweeps them.
    pub fn get(&self, shape_hash: u64, current_reltuples: Option<u64>) -> Option<ExplainEstimate> {
        let entries = self.entries.lock().ok()?;
        let entry = entries.get(&shape_hash)?;

        if entry.cached_at.elapsed() >= self.ttl {
            return None;
        }

        let drifted = current_reltuples
            .is_some_and(|current| Self::row_estimate_drifted(entry.plan_rows, current));
        if drifted {
            return None; // Stale — significant data skew detected
        }

        Some(entry.estimate.clone())
    }

    /// Store an estimate in the cache.
    ///
    /// Expired entries are swept once the map is at least half full. If the
    /// map is still at capacity afterwards, a *new* key is dropped, but a key
    /// that is already cached is always overwritten: replacing it cannot grow
    /// the map, and refusing would pin a drift-invalidated entry until its
    /// TTL lapses. A poisoned mutex makes this a no-op.
    pub fn insert(&self, shape_hash: u64, estimate: ExplainEstimate) {
        let Ok(mut entries) = self.entries.lock() else {
            return;
        };

        // Evict expired entries when approaching capacity
        if entries.len() >= self.max_entries / 2 {
            let ttl = self.ttl;
            entries.retain(|_, cached| cached.cached_at.elapsed() < ttl);
        }

        // Hard cap applies only to inserts that would add a new key.
        if entries.len() >= self.max_entries && !entries.contains_key(&shape_hash) {
            return;
        }

        entries.insert(
            shape_hash,
            CachedEstimate {
                plan_rows: estimate.plan_rows,
                estimate,
                cached_at: Instant::now(),
            },
        );
    }

    /// Number of cached entries (for metrics); 0 if the mutex is poisoned.
    pub fn len(&self) -> usize {
        self.entries.lock().map(|e| e.len()).unwrap_or(0)
    }

    /// Whether the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns true when `current_rows` differs from `cached_rows` by more
    /// than both the relative and the absolute drift thresholds.
    ///
    /// A cached value of 0 never counts as drifted (no baseline to compare).
    fn row_estimate_drifted(cached_rows: u64, current_rows: u64) -> bool {
        if cached_rows == 0 {
            return false;
        }
        let cached = cached_rows as f64;
        let relative = (current_rows as f64 - cached).abs() / cached;
        // `abs_diff` stays in u64, so counts above i64::MAX cannot wrap or overflow.
        let absolute = current_rows.abs_diff(cached_rows);
        relative > Self::DRIFT_RATIO && absolute > Self::DRIFT_MIN_ROWS
    }
}
164
165/// Parse `EXPLAIN (FORMAT JSON)` output to extract cost estimates.
166///
167/// Uses lightweight string parsing to avoid adding serde_json as a
168/// dependency to the pg driver crate. The EXPLAIN JSON format is stable:
169/// ```json
170/// [{"Plan": {"Total Cost": 1234.56, "Plan Rows": 5000, ...}}]
171/// ```
172pub fn parse_explain_json(json_str: &str) -> Option<ExplainEstimate> {
173    let total_cost = extract_json_number(json_str, "Total Cost")?;
174    if !total_cost.is_finite() || total_cost < 0.0 {
175        return None;
176    }
177
178    let plan_rows = extract_json_number(json_str, "Plan Rows")?;
179    if !plan_rows.is_finite() || plan_rows < 0.0 || plan_rows > u64::MAX as f64 {
180        return None;
181    }
182    let plan_rows = plan_rows as u64;
183
184    Some(ExplainEstimate {
185        total_cost,
186        plan_rows,
187    })
188}
189
/// Extract the numeric value that follows the first `"key":` in `json`.
///
/// Whitespace after the colon is skipped, then the longest run of number
/// characters (ASCII digits, `.`, `-`, `+`, `e`, `E`) is parsed as `f64`.
/// The run may extend to the end of the input, so a value with no trailing
/// delimiter is still accepted.
///
/// Returns `None` if the key is absent or the run is not a valid number
/// (including an empty run, e.g. when the value is a string).
fn extract_json_number(json: &str, key: &str) -> Option<f64> {
    let pattern = format!("\"{}\":", key);
    let (_, after_key) = json.split_once(pattern.as_str())?;
    let value = after_key.trim_start();

    let is_number_char =
        |c: char| c.is_ascii_digit() || matches!(c, '.' | '-' | '+' | 'e' | 'E');

    // No terminator means the number runs to the end of the input.
    let end = value
        .find(|c: char| !is_number_char(c))
        .unwrap_or(value.len());

    value[..end].parse::<f64>().ok()
}
206
/// Outcome of evaluating a query against the EXPLAIN cost guard.
#[derive(Debug)]
pub enum ExplainDecision {
    /// The query may proceed.
    Allow,
    /// The query is refused; the fields carry the numbers behind the refusal.
    Reject {
        /// PostgreSQL's estimated total cost for the query.
        total_cost: f64,
        /// PostgreSQL's estimated row count.
        plan_rows: u64,
        /// Configured maximum cost threshold.
        max_cost: f64,
        /// Configured maximum row threshold.
        max_rows: u64,
    },
    /// EXPLAIN was skipped (mode=Off or below depth threshold).
    Skipped,
}

/// Structured payload describing an EXPLAIN cost guard rejection.
#[derive(Debug, Clone)]
pub struct ExplainRejectionDetail {
    /// PostgreSQL's estimated total cost for the query.
    pub estimated_cost: f64,
    /// Configured maximum cost threshold.
    pub cost_limit: f64,
    /// PostgreSQL's estimated row count.
    pub estimated_rows: u64,
    /// Configured maximum row threshold.
    pub row_limit: u64,
    /// Actionable suggestions to bring the query under limits.
    pub suggestions: Vec<String>,
}

impl ExplainDecision {
    /// Returns true only for the `Reject` variant.
    pub fn is_rejected(&self) -> bool {
        match self {
            ExplainDecision::Reject { .. } => true,
            ExplainDecision::Allow | ExplainDecision::Skipped => false,
        }
    }

    /// Human-readable rejection message for API responses.
    ///
    /// Returns `None` for `Allow` and `Skipped`. Costs are rendered without
    /// decimals.
    pub fn rejection_message(&self) -> Option<String> {
        let ExplainDecision::Reject {
            total_cost,
            plan_rows,
            max_cost,
            max_rows,
        } = self
        else {
            return None;
        };

        let message = format!(
            "Query rejected: estimated cost {:.0} exceeds limit {:.0}, \
             or estimated rows {} exceeds limit {}. \
             Try narrowing your filters, reducing ?expand depth, or using pagination.",
            total_cost, max_cost, plan_rows, max_rows
        );
        Some(message)
    }

    /// Machine-readable rejection detail for structured API error responses.
    ///
    /// Returns `None` for `Allow` and `Skipped`. For `Reject`, the estimates
    /// and limits are copied over alongside a fixed list of four suggestions.
    pub fn rejection_detail(&self) -> Option<ExplainRejectionDetail> {
        let ExplainDecision::Reject {
            total_cost,
            plan_rows,
            max_cost,
            max_rows,
        } = self
        else {
            return None;
        };

        let suggestions = [
            "Add WHERE clauses to narrow the result set",
            "Reduce ?expand depth (deep JOINs multiply cost)",
            "Use ?limit and ?offset for pagination",
            "Add indexes on frequently filtered columns",
        ]
        .iter()
        .map(|&hint| String::from(hint))
        .collect();

        Some(ExplainRejectionDetail {
            estimated_cost: *total_cost,
            cost_limit: *max_cost,
            estimated_rows: *plan_rows,
            row_limit: *max_rows,
            suggestions,
        })
    }
}
293
294/// Check an EXPLAIN estimate against configured thresholds.
295pub fn check_estimate(estimate: &ExplainEstimate, config: &ExplainConfig) -> ExplainDecision {
296    if !estimate.total_cost.is_finite()
297        || !config.max_total_cost.is_finite()
298        || estimate.total_cost > config.max_total_cost
299        || estimate.plan_rows > config.max_plan_rows
300    {
301        ExplainDecision::Reject {
302            total_cost: estimate.total_cost,
303            plan_rows: estimate.plan_rows,
304            max_cost: config.max_total_cost,
305            max_rows: config.max_plan_rows,
306        }
307    } else {
308        ExplainDecision::Allow
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing floating-point costs.
    const COST_EPSILON: f64 = 0.01;

    /// Shorthand for building an estimate.
    fn estimate(total_cost: f64, plan_rows: u64) -> ExplainEstimate {
        ExplainEstimate {
            total_cost,
            plan_rows,
        }
    }

    /// Asserts two costs agree within `COST_EPSILON`.
    fn assert_cost_eq(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < COST_EPSILON);
    }

    #[test]
    fn test_parse_explain_json() {
        let parsed = parse_explain_json(
            r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 1234.56, "Plan Rows": 5000, "Plan Width": 100}}]"#,
        )
        .unwrap();
        assert_cost_eq(parsed.total_cost, 1234.56);
        assert_eq!(parsed.plan_rows, 5000);
    }

    #[test]
    fn test_parse_explain_json_nested_join() {
        let parsed = parse_explain_json(
            r#"[{"Plan": {"Node Type": "Hash Join", "Total Cost": 250000.0, "Plan Rows": 2000000, "Plan Width": 200}}]"#,
        )
        .unwrap();
        assert_cost_eq(parsed.total_cost, 250000.0);
        assert_eq!(parsed.plan_rows, 2_000_000);
    }

    #[test]
    fn test_parse_explain_json_invalid() {
        // Not JSON, no plan, overflowing cost, negative row count.
        let rejected_inputs = [
            "not json",
            "{}",
            "[]",
            r#"[{"Plan": {"Total Cost": 1e999, "Plan Rows": 5000}}]"#,
            r#"[{"Plan": {"Total Cost": 100.0, "Plan Rows": -1}}]"#,
        ];
        for input in rejected_inputs.iter() {
            assert!(parse_explain_json(input).is_none());
        }
    }

    #[test]
    fn test_check_estimate_allow() {
        let decision = check_estimate(&estimate(100.0, 500), &ExplainConfig::default());
        assert!(!decision.is_rejected());
    }

    #[test]
    fn test_check_estimate_reject_cost() {
        let decision = check_estimate(&estimate(200_000.0, 500), &ExplainConfig::default());
        assert!(decision.is_rejected());
        assert!(decision.rejection_message().unwrap().contains("200000"));
    }

    #[test]
    fn test_check_estimate_reject_rows() {
        let decision = check_estimate(&estimate(50.0, 5_000_000), &ExplainConfig::default());
        assert!(decision.is_rejected());
    }

    #[test]
    fn test_check_estimate_rejects_non_finite_costs() {
        // NaN limit in the configuration.
        let nan_limit = ExplainConfig {
            max_total_cost: f64::NAN,
            ..Default::default()
        };
        assert!(check_estimate(&estimate(50.0, 500), &nan_limit).is_rejected());

        // Infinite cost in the estimate.
        let defaults = ExplainConfig::default();
        assert!(check_estimate(&estimate(f64::INFINITY, 500), &defaults).is_rejected());
    }

    #[test]
    fn test_cache_basic() {
        let cache = ExplainCache::new(Duration::from_secs(60));
        assert!(cache.is_empty());

        cache.insert(42, estimate(100.0, 50));
        assert_eq!(cache.len(), 1);

        let hit = cache.get(42, None).unwrap();
        assert_cost_eq(hit.total_cost, 100.0);
        assert_eq!(hit.plan_rows, 50);

        // Miss for unknown key
        assert!(cache.get(99, None).is_none());
    }

    #[test]
    fn test_cache_expiry() {
        let cache = ExplainCache::new(Duration::from_millis(1));
        cache.insert(1, estimate(100.0, 50));

        // Let the 1 ms TTL lapse before reading back.
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get(1, None).is_none());
    }

    #[test]
    fn test_cache_drift_invalidation() {
        let cache = ExplainCache::new(Duration::from_secs(60));

        // ── Small dataset: relative drift alone should NOT invalidate ──
        cache.insert(1, estimate(50.0, 1000));

        // No reltuples — pure TTL, should hit
        assert!(cache.get(1, None).is_some());

        // Same estimate — no drift, should hit
        assert!(cache.get(1, Some(1000)).is_some());

        // 60% relative drift but only 600 absolute — below 10k floor, should STILL hit
        assert!(
            cache.get(1, Some(1600)).is_some(),
            "small table should not thrash"
        );

        // 60% shrinkage but only 600 absolute — should STILL hit
        assert!(
            cache.get(1, Some(400)).is_some(),
            "small shrinkage should not thrash"
        );

        // ── Large dataset: BOTH relative AND absolute thresholds exceeded ──
        cache.insert(3, estimate(500.0, 50_000));

        // 70% drift + 35k absolute (both above threshold) — should miss
        assert!(
            cache.get(3, Some(85_000)).is_none(),
            "large drift should invalidate"
        );

        // 40% drift + 20k absolute (relative below 50%) — should STILL hit
        assert!(
            cache.get(3, Some(70_000)).is_some(),
            "moderate drift should not invalidate"
        );

        // 60% shrinkage + 30k absolute (both above threshold) — should miss
        assert!(
            cache.get(3, Some(20_000)).is_none(),
            "large shrinkage should invalidate"
        );

        // Edge: plan_rows = 0 in cache — skip drift check entirely
        cache.insert(2, estimate(10.0, 0));
        assert!(cache.get(2, Some(999_999)).is_some());
    }

    #[test]
    fn test_explain_mode_default() {
        let defaults = ExplainConfig::default();
        assert_eq!(defaults.mode, ExplainMode::Precheck);
        assert_eq!(defaults.depth_threshold, 3);
        assert_cost_eq(defaults.max_total_cost, 100_000.0);
        assert_eq!(defaults.max_plan_rows, 1_000_000);
    }
}