Skip to main content

qail_pg/driver/
explain.rs

1//! EXPLAIN-based query cost estimation for pre-check rejection.
2//!
3//! Provides runtime cost-based rejection of queries that would be too
4//! expensive, using PostgreSQL's `EXPLAIN (FORMAT JSON)` output.
5//!
6//! # Modes
7//! - **Off**: No EXPLAIN pre-check
8//! - **Precheck**: Run EXPLAIN on cache-miss for queries with expand depth ≥ threshold
9//! - **Enforce**: Always run EXPLAIN and enforce cost thresholds
10//!
11//! # Caching
12//! EXPLAIN results are cached by `AST_shape_hash + rls_signature` with configurable TTL.
13//! This avoids repeated EXPLAIN calls for the same query shape.
14
15use std::collections::HashMap;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
/// Configuration for EXPLAIN pre-check behavior.
///
/// Start from [`ExplainConfig::default`] and override individual fields with
/// struct-update syntax, e.g.
/// `ExplainConfig { max_total_cost: 50_000.0, ..Default::default() }`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExplainConfig {
    /// Operating mode for EXPLAIN pre-check.
    pub mode: ExplainMode,

    /// Run EXPLAIN for queries with `expand_depth >= depth_threshold`.
    /// Default: 3 (queries with three or more `?expand` levels get pre-checked).
    pub depth_threshold: usize,

    /// Reject if PostgreSQL's estimated total cost exceeds this.
    /// Default: 100,000 (unitless PostgreSQL planner cost).
    pub max_total_cost: f64,

    /// Reject if PostgreSQL estimates more rows than this.
    /// Default: 1,000,000 rows.
    pub max_plan_rows: u64,

    /// TTL for cached EXPLAIN results.
    /// Default: 5 minutes.
    pub cache_ttl: Duration,
}

impl Default for ExplainConfig {
    /// Production-oriented defaults: `Precheck` mode, depth threshold 3,
    /// cost limit 100,000, row limit 1,000,000 and a 5-minute cache TTL.
    fn default() -> Self {
        Self {
            mode: ExplainMode::Precheck,
            depth_threshold: 3,
            max_total_cost: 100_000.0,
            max_plan_rows: 1_000_000,
            cache_ttl: Duration::from_secs(300),
        }
    }
}

/// Operating mode for EXPLAIN pre-check.
///
/// A fieldless enum, so it is `Copy`. Its `Default` is
/// [`ExplainMode::Precheck`], the same mode [`ExplainConfig::default`] uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ExplainMode {
    /// No EXPLAIN pre-check — fastest, no protection.
    Off,
    /// Run EXPLAIN on cache-miss for queries above depth threshold.
    /// Recommended default for production.
    #[default]
    Precheck,
    /// Always run EXPLAIN and enforce — strictest, slight latency cost.
    /// Recommended for staging or high-security tenants.
    Enforce,
}
66
/// Result of an EXPLAIN pre-check.
#[derive(Debug, Clone, PartialEq)]
pub struct ExplainEstimate {
    /// PostgreSQL's estimated total cost (arbitrary units).
    pub total_cost: f64,
    /// PostgreSQL's estimated number of rows returned.
    pub plan_rows: u64,
}

/// Cached EXPLAIN result with TTL and row-estimate snapshot.
struct CachedEstimate {
    estimate: ExplainEstimate,
    /// Insertion time; the entry is live while `cached_at.elapsed() < ttl`.
    cached_at: Instant,
    /// Row estimate snapshot at cache time, for drift detection.
    plan_rows: u64,
}

/// In-memory cache for EXPLAIN estimates, keyed by AST shape hash.
///
/// The map sits behind a `Mutex`. If that lock is poisoned, the cache degrades
/// to a permanent miss: `get` returns `None`, `insert` is a no-op and `len`
/// reports 0.
pub struct ExplainCache {
    entries: Mutex<HashMap<u64, CachedEstimate>>,
    ttl: Duration,
    /// Maximum number of cached entries to prevent OOM from shape explosion.
    max_entries: usize,
}

impl ExplainCache {
    /// Capacity used by [`ExplainCache::new`].
    const DEFAULT_MAX_ENTRIES: usize = 10_000;
    /// Relative drift (fraction of the cached row estimate) that must be exceeded.
    const MAX_DRIFT_RATIO: f64 = 0.5;
    /// Absolute drift in rows that must also be exceeded (prevents small-table thrash).
    const MIN_DRIFT_ROWS: u64 = 10_000;

    /// Create a new EXPLAIN cache with the given TTL and room for 10,000 shapes.
    pub fn new(ttl: Duration) -> Self {
        Self::with_max_entries(ttl, Self::DEFAULT_MAX_ENTRIES)
    }

    /// Create a new EXPLAIN cache with the given TTL and an explicit capacity.
    ///
    /// A `max_entries` of 0 disables caching: every insert is dropped.
    pub fn with_max_entries(ttl: Duration, max_entries: usize) -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            ttl,
            max_entries,
        }
    }

    /// Get a cached estimate if it exists, hasn't expired, and its row estimate
    /// hasn't drifted.
    ///
    /// `current_reltuples` is the current `pg_class.reltuples` for the primary
    /// table. When it is provided, the entry is treated as stale only if the
    /// current value differs from the cached `plan_rows` by more than 50%
    /// **and** by more than 10,000 rows. An entry whose cached `plan_rows` is 0
    /// skips the drift check. Expired entries are not removed here; `insert`
    /// sweeps them.
    ///
    /// NOTE(review): this compares the plan's *output* row estimate with the
    /// table's *total* row count. For selective queries these legitimately
    /// differ by orders of magnitude, which would make such entries miss on
    /// every call. Confirm callers pass a comparable figure.
    pub fn get(&self, shape_hash: u64, current_reltuples: Option<u64>) -> Option<ExplainEstimate> {
        let entries = self.entries.lock().ok()?;
        let entry = entries.get(&shape_hash)?;
        if entry.cached_at.elapsed() >= self.ttl {
            return None;
        }
        match current_reltuples {
            Some(current) if Self::has_drifted(entry.plan_rows, current) => None,
            _ => Some(entry.estimate.clone()),
        }
    }

    /// Whether `current` has drifted far enough from `cached` to invalidate an entry.
    ///
    /// Both the relative and the absolute threshold must be exceeded. The
    /// absolute delta uses `abs_diff`, so saturated estimates (e.g. `u64::MAX`
    /// from an `f64 as u64` cast) cannot wrap or overflow the way a signed
    /// subtraction of `as i64` casts would.
    fn has_drifted(cached: u64, current: u64) -> bool {
        if cached == 0 {
            return false;
        }
        let cached_f = cached as f64;
        let ratio = ((current as f64) - cached_f).abs() / cached_f;
        ratio > Self::MAX_DRIFT_RATIO && current.abs_diff(cached) > Self::MIN_DRIFT_ROWS
    }

    /// Store an estimate in the cache, replacing any existing entry for `shape_hash`.
    ///
    /// Once the map is at least half full, expired entries are swept first. If
    /// it is still at capacity afterwards, a *new* shape is dropped. An
    /// existing shape is always refreshed, since replacing it does not grow the
    /// map. This lets an entry invalidated by drift be re-populated even when
    /// the cache is full.
    pub fn insert(&self, shape_hash: u64, estimate: ExplainEstimate) {
        let mut entries = match self.entries.lock() {
            Ok(guard) => guard,
            // Poisoned lock: caching is best-effort, so just skip.
            Err(_) => return,
        };
        if entries.len() >= self.max_entries / 2 {
            let ttl = self.ttl;
            entries.retain(|_, cached| cached.cached_at.elapsed() < ttl);
        }
        if entries.len() >= self.max_entries && !entries.contains_key(&shape_hash) {
            return;
        }
        entries.insert(
            shape_hash,
            CachedEstimate {
                plan_rows: estimate.plan_rows,
                estimate,
                cached_at: Instant::now(),
            },
        );
    }

    /// Number of cached entries, including expired ones not yet swept (for metrics).
    pub fn len(&self) -> usize {
        self.entries.lock().map(|e| e.len()).unwrap_or(0)
    }

    /// Whether the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
164
165/// Parse `EXPLAIN (FORMAT JSON)` output to extract cost estimates.
166///
167/// Uses lightweight string parsing to avoid adding serde_json as a
168/// dependency to the pg driver crate. The EXPLAIN JSON format is stable:
169/// ```json
170/// [{"Plan": {"Total Cost": 1234.56, "Plan Rows": 5000, ...}}]
171/// ```
172pub fn parse_explain_json(json_str: &str) -> Option<ExplainEstimate> {
173    let total_cost = extract_json_number(json_str, "Total Cost")?;
174    let plan_rows = extract_json_number(json_str, "Plan Rows")? as u64;
175
176    Some(ExplainEstimate {
177        total_cost,
178        plan_rows,
179    })
180}
181
182/// Extract a numeric value after `"key":` from a JSON string.
183fn extract_json_number(json: &str, key: &str) -> Option<f64> {
184    let pattern = format!("\"{}\":", key);
185    let start = json.find(&pattern)?;
186    let after_key = &json[start + pattern.len()..];
187
188    // Skip whitespace
189    let trimmed = after_key.trim_start();
190
191    // Parse the number (may be integer or float)
192    let end = trimmed.find(|c: char| {
193        !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
194    })?;
195    let num_str = &trimmed[..end];
196    num_str.parse::<f64>().ok()
197}
198
/// Decision from the EXPLAIN pre-check.
#[derive(Debug, Clone, PartialEq)]
pub enum ExplainDecision {
    /// Query is allowed to proceed.
    Allow,
    /// Query is rejected with an explanation.
    Reject {
        /// PostgreSQL's estimated total cost for the query.
        total_cost: f64,
        /// PostgreSQL's estimated row count.
        plan_rows: u64,
        /// Configured maximum cost threshold.
        max_cost: f64,
        /// Configured maximum row threshold.
        max_rows: u64,
    },
    /// EXPLAIN was skipped (mode=Off or below depth threshold).
    Skipped,
}

impl ExplainDecision {
    /// Remediation hints attached to every structured rejection.
    const REJECTION_SUGGESTIONS: [&'static str; 4] = [
        "Add WHERE clauses to narrow the result set",
        "Reduce ?expand depth (deep JOINs multiply cost)",
        "Use ?limit and ?offset for pagination",
        "Add indexes on frequently filtered columns",
    ];

    /// Returns true if the query should be rejected.
    pub fn is_rejected(&self) -> bool {
        matches!(self, ExplainDecision::Reject { .. })
    }

    /// Human-readable rejection message for API responses.
    ///
    /// Names only the limit(s) actually exceeded, so a cost-only rejection
    /// does not claim the row estimate was also over its limit. If neither
    /// figure exceeds its limit (e.g. a hand-built `Reject`), both figures are
    /// reported next to their limits. Returns `None` for `Allow` and `Skipped`.
    pub fn rejection_message(&self) -> Option<String> {
        let (total_cost, plan_rows, max_cost, max_rows) = match *self {
            ExplainDecision::Reject {
                total_cost,
                plan_rows,
                max_cost,
                max_rows,
            } => (total_cost, plan_rows, max_cost, max_rows),
            ExplainDecision::Allow | ExplainDecision::Skipped => return None,
        };

        let mut reasons: Vec<String> = Vec::with_capacity(2);
        if total_cost > max_cost {
            reasons.push(format!(
                "estimated cost {:.0} exceeds limit {:.0}",
                total_cost, max_cost
            ));
        }
        if plan_rows > max_rows {
            reasons.push(format!(
                "estimated rows {} exceeds limit {}",
                plan_rows, max_rows
            ));
        }
        if reasons.is_empty() {
            // No visible violation: report the figures without claiming either exceeded.
            reasons.push(format!(
                "estimated cost {:.0} (limit {:.0}), estimated rows {} (limit {})",
                total_cost, max_cost, plan_rows, max_rows
            ));
        }

        Some(format!(
            "Query rejected: {}. \
             Try narrowing your filters, reducing ?expand depth, or using pagination.",
            reasons.join(" and ")
        ))
    }

    /// Machine-readable rejection detail for structured API error responses.
    ///
    /// Returns `None` for `Allow` and `Skipped` decisions.
    /// Client SDKs can use this to programmatically react to cost rejections.
    pub fn rejection_detail(&self) -> Option<ExplainRejectionDetail> {
        match *self {
            ExplainDecision::Reject {
                total_cost,
                plan_rows,
                max_cost,
                max_rows,
            } => Some(ExplainRejectionDetail {
                estimated_cost: total_cost,
                cost_limit: max_cost,
                estimated_rows: plan_rows,
                row_limit: max_rows,
                suggestions: Self::REJECTION_SUGGESTIONS
                    .iter()
                    .map(|s| s.to_string())
                    .collect(),
            }),
            ExplainDecision::Allow | ExplainDecision::Skipped => None,
        }
    }
}

/// Structured rejection detail for EXPLAIN cost guard violations.
#[derive(Debug, Clone, PartialEq)]
pub struct ExplainRejectionDetail {
    /// PostgreSQL's estimated total cost for the query.
    pub estimated_cost: f64,
    /// Configured maximum cost threshold.
    pub cost_limit: f64,
    /// PostgreSQL's estimated row count.
    pub estimated_rows: u64,
    /// Configured maximum row threshold.
    pub row_limit: u64,
    /// Actionable suggestions to bring the query under limits.
    pub suggestions: Vec<String>,
}
285
286/// Check an EXPLAIN estimate against configured thresholds.
287pub fn check_estimate(estimate: &ExplainEstimate, config: &ExplainConfig) -> ExplainDecision {
288    if estimate.total_cost > config.max_total_cost || estimate.plan_rows > config.max_plan_rows {
289        ExplainDecision::Reject {
290            total_cost: estimate.total_cost,
291            plan_rows: estimate.plan_rows,
292            max_cost: config.max_total_cost,
293            max_rows: config.max_plan_rows,
294        }
295    } else {
296        ExplainDecision::Allow
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an estimate in tests.
    fn estimate(total_cost: f64, plan_rows: u64) -> ExplainEstimate {
        ExplainEstimate {
            total_cost,
            plan_rows,
        }
    }

    /// Float comparison with the tolerance used throughout these tests.
    fn approx(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 0.01
    }

    #[test]
    fn test_parse_explain_json() {
        let raw = r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 1234.56, "Plan Rows": 5000, "Plan Width": 100}}]"#;
        let parsed = parse_explain_json(raw).expect("valid EXPLAIN JSON should parse");
        assert!(approx(parsed.total_cost, 1234.56));
        assert_eq!(parsed.plan_rows, 5000);
    }

    #[test]
    fn test_parse_explain_json_nested_join() {
        let raw = r#"[{"Plan": {"Node Type": "Hash Join", "Total Cost": 250000.0, "Plan Rows": 2000000, "Plan Width": 200}}]"#;
        let parsed = parse_explain_json(raw).expect("valid EXPLAIN JSON should parse");
        assert!(approx(parsed.total_cost, 250000.0));
        assert_eq!(parsed.plan_rows, 2_000_000);
    }

    #[test]
    fn test_parse_explain_json_invalid() {
        for bad in ["not json", "{}", "[]"] {
            assert!(parse_explain_json(bad).is_none(), "{:?} should not parse", bad);
        }
    }

    #[test]
    fn test_check_estimate_allow() {
        let decision = check_estimate(&estimate(100.0, 500), &ExplainConfig::default());
        assert!(!decision.is_rejected());
    }

    #[test]
    fn test_check_estimate_reject_cost() {
        let decision = check_estimate(&estimate(200_000.0, 500), &ExplainConfig::default());
        assert!(decision.is_rejected());
        let message = decision
            .rejection_message()
            .expect("a rejected decision carries a message");
        assert!(message.contains("200000"));
    }

    #[test]
    fn test_check_estimate_reject_rows() {
        let decision = check_estimate(&estimate(50.0, 5_000_000), &ExplainConfig::default());
        assert!(decision.is_rejected());
    }

    #[test]
    fn test_cache_basic() {
        let cache = ExplainCache::new(Duration::from_secs(60));
        assert!(cache.is_empty());

        cache.insert(42, estimate(100.0, 50));
        assert_eq!(cache.len(), 1);

        let hit = cache.get(42, None).expect("fresh entry should hit");
        assert!(approx(hit.total_cost, 100.0));
        assert_eq!(hit.plan_rows, 50);

        assert!(cache.get(99, None).is_none(), "unknown key must miss");
    }

    #[test]
    fn test_cache_expiry() {
        let cache = ExplainCache::new(Duration::from_millis(1));
        cache.insert(1, estimate(100.0, 50));

        // Outlive the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get(1, None).is_none());
    }

    #[test]
    fn test_cache_drift_invalidation() {
        let cache = ExplainCache::new(Duration::from_secs(60));

        // Small dataset (1,000 rows): relative drift alone must NOT invalidate.
        cache.insert(1, estimate(50.0, 1000));
        let small_cases: [(Option<u64>, &str); 4] = [
            (None, "no reltuples: pure TTL should hit"),
            (Some(1000), "no drift should hit"),
            (Some(1600), "small table should not thrash"),     // +60%, +600 rows
            (Some(400), "small shrinkage should not thrash"), // -60%, -600 rows
        ];
        for (reltuples, why) in small_cases {
            assert!(cache.get(1, reltuples).is_some(), "{}", why);
        }

        // Large dataset (50,000 rows): invalidate only when BOTH thresholds are exceeded.
        cache.insert(3, estimate(500.0, 50_000));
        let large_cases: [(u64, bool, &str); 3] = [
            (85_000, false, "large drift should invalidate"),        // +70%, +35k rows
            (70_000, true, "moderate drift should not invalidate"), // +40%, +20k rows
            (20_000, false, "large shrinkage should invalidate"),   // -60%, -30k rows
        ];
        for (reltuples, expect_hit, why) in large_cases {
            assert_eq!(cache.get(3, Some(reltuples)).is_some(), expect_hit, "{}", why);
        }

        // A cached estimate of zero rows skips the drift check entirely.
        cache.insert(2, estimate(10.0, 0));
        assert!(cache.get(2, Some(999_999)).is_some());
    }

    #[test]
    fn test_explain_mode_default() {
        let ExplainConfig {
            mode,
            depth_threshold,
            max_total_cost,
            max_plan_rows,
            ..
        } = ExplainConfig::default();
        assert_eq!(mode, ExplainMode::Precheck);
        assert_eq!(depth_threshold, 3);
        assert!(approx(max_total_cost, 100_000.0));
        assert_eq!(max_plan_rows, 1_000_000);
    }
}