//! EXPLAIN-based query cost estimation for pre-check rejection.
//!
//! Provides runtime cost-based rejection of queries that would be too
//! expensive, using PostgreSQL's `EXPLAIN (FORMAT JSON)` output.
//!
//! # Modes
//! - **Off**: No EXPLAIN pre-check
//! - **Precheck**: Run EXPLAIN on cache-miss for queries with expand depth ≥ threshold
//! - **Enforce**: Always run EXPLAIN and enforce cost thresholds
//!
//! # Caching
//! EXPLAIN results are cached by `AST_shape_hash + rls_signature` with configurable TTL.
//! This avoids repeated EXPLAIN calls for the same query shape.
use std::collections::HashMap;
use std::sync::Mutex;
use std::time::{Duration, Instant};
18
/// Configuration for EXPLAIN pre-check behavior.
#[derive(Debug, Clone)]
pub struct ExplainConfig {
    /// Operating mode for the EXPLAIN pre-check.
    pub mode: ExplainMode,

    /// Queries with `expand_depth >= depth_threshold` get pre-checked.
    /// Default: 3 (i.e. queries joining 3+ tables).
    pub depth_threshold: usize,

    /// Reject when PostgreSQL's estimated total cost exceeds this value.
    /// Default: 100,000 (unitless planner cost).
    pub max_total_cost: f64,

    /// Reject when PostgreSQL's row estimate exceeds this value.
    /// Default: 1,000,000 rows.
    pub max_plan_rows: u64,

    /// Time-to-live for cached EXPLAIN results. Default: 5 minutes.
    pub cache_ttl: Duration,
}

impl Default for ExplainConfig {
    /// Production-oriented defaults: pre-check deep queries only, with
    /// moderate cost/row ceilings and a 5-minute result cache.
    fn default() -> Self {
        ExplainConfig {
            mode: ExplainMode::Precheck,
            depth_threshold: 3,
            max_total_cost: 100_000.0,
            max_plan_rows: 1_000_000,
            cache_ttl: Duration::from_secs(300),
        }
    }
}

/// Operating mode for the EXPLAIN pre-check.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExplainMode {
    /// No EXPLAIN pre-check — fastest, but no protection.
    Off,
    /// Run EXPLAIN on cache-miss for queries above the depth threshold.
    /// Recommended default for production.
    Precheck,
    /// Always run EXPLAIN and enforce thresholds — strictest, with a
    /// slight latency cost. Recommended for staging or high-security tenants.
    Enforce,
}
66
/// Result of an EXPLAIN pre-check.
#[derive(Debug, Clone)]
pub struct ExplainEstimate {
    /// PostgreSQL's estimated total cost (arbitrary planner units).
    pub total_cost: f64,
    /// PostgreSQL's estimated number of rows returned.
    pub plan_rows: u64,
}

/// Cached EXPLAIN result with TTL and row-estimate snapshot.
struct CachedEstimate {
    estimate: ExplainEstimate,
    /// Insertion time; compared against the cache TTL on every read.
    cached_at: Instant,
    /// Row estimate snapshot at cache time, for drift detection.
    plan_rows: u64,
}

/// In-memory cache for EXPLAIN estimates, keyed by AST shape hash.
pub struct ExplainCache {
    entries: Mutex<HashMap<u64, CachedEstimate>>,
    ttl: Duration,
    /// Maximum number of cached entries to prevent OOM from shape explosion.
    max_entries: usize,
}

impl ExplainCache {
    /// Relative row-estimate change beyond which an entry is stale.
    const DRIFT_RATIO: f64 = 0.5;
    /// Absolute row delta that must ALSO be exceeded before invalidating,
    /// so small tables don't thrash on large-but-tiny relative changes.
    const DRIFT_ABS_FLOOR: u64 = 10_000;

    /// Create a new EXPLAIN cache with the given TTL.
    pub fn new(ttl: Duration) -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            ttl,
            max_entries: 10_000,
        }
    }

    /// Get a cached estimate if it exists, hasn't expired, and the row
    /// estimate hasn't drifted beyond 50%.
    ///
    /// `current_reltuples` is the current `pg_class.reltuples` for the
    /// primary table. If provided and the cached plan_rows have drifted
    /// >50% relative AND >10,000 rows absolute, the entry is considered
    /// stale (data skew) and treated as a miss. A poisoned lock also
    /// reads as a miss.
    pub fn get(&self, shape_hash: u64, current_reltuples: Option<u64>) -> Option<ExplainEstimate> {
        let entries = self.entries.lock().ok()?;
        let entry = entries.get(&shape_hash)?;
        if entry.cached_at.elapsed() >= self.ttl {
            return None; // Expired by TTL
        }
        // Row-estimate drift check: invalidate only when BOTH hold:
        // 1. Relative change > DRIFT_RATIO (data skew)
        // 2. Absolute delta > DRIFT_ABS_FLOOR (prevents small-table thrash)
        if let Some(current) = current_reltuples {
            if entry.plan_rows > 0 {
                let cached = entry.plan_rows as f64;
                let drift = ((current as f64) - cached).abs() / cached;
                // u64::abs_diff avoids the lossy, overflow-prone i64 casts
                // the original used for the absolute delta.
                let abs_delta = current.abs_diff(entry.plan_rows);
                if drift > Self::DRIFT_RATIO && abs_delta > Self::DRIFT_ABS_FLOOR {
                    return None; // Stale — significant data skew detected
                }
            }
        }
        Some(entry.estimate.clone())
    }

    /// Store an estimate in the cache.
    ///
    /// When the map reaches half capacity, expired entries are evicted
    /// first; if the hard cap is still reached afterwards, the insert is
    /// skipped rather than evicting live entries.
    pub fn insert(&self, shape_hash: u64, estimate: ExplainEstimate) {
        let Ok(mut entries) = self.entries.lock() else {
            return; // Poisoned lock: drop the write; reads degrade to misses
        };
        if entries.len() >= self.max_entries / 2 {
            let ttl = self.ttl;
            entries.retain(|_, v| v.cached_at.elapsed() < ttl);
        }
        if entries.len() >= self.max_entries {
            return; // Hard cap: still full after eviction
        }
        entries.insert(
            shape_hash,
            CachedEstimate {
                plan_rows: estimate.plan_rows,
                estimate,
                cached_at: Instant::now(),
            },
        );
    }

    /// Number of cached entries (for metrics). Returns 0 on a poisoned lock.
    pub fn len(&self) -> usize {
        self.entries.lock().map(|e| e.len()).unwrap_or(0)
    }

    /// Whether the cache is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
161
162/// Parse `EXPLAIN (FORMAT JSON)` output to extract cost estimates.
163///
164/// Uses lightweight string parsing to avoid adding serde_json as a
165/// dependency to the pg driver crate. The EXPLAIN JSON format is stable:
166/// ```json
167/// [{"Plan": {"Total Cost": 1234.56, "Plan Rows": 5000, ...}}]
168/// ```
169pub fn parse_explain_json(json_str: &str) -> Option<ExplainEstimate> {
170    let total_cost = extract_json_number(json_str, "Total Cost")?;
171    let plan_rows = extract_json_number(json_str, "Plan Rows")? as u64;
172
173    Some(ExplainEstimate {
174        total_cost,
175        plan_rows,
176    })
177}
178
/// Extract a numeric value after `"key":` from a JSON string.
///
/// Handles integers, floats, and scientific notation (`e`/`E` with an
/// optional sign). Returns `None` when the key is absent or the value
/// cannot be parsed as a number.
fn extract_json_number(json: &str, key: &str) -> Option<f64> {
    let pattern = format!("\"{}\":", key);
    let start = json.find(&pattern)?;
    let after_key = &json[start + pattern.len()..];

    // Skip whitespace between the colon and the value.
    let trimmed = after_key.trim_start();

    // The number runs until the first character that cannot be part of a
    // JSON number. If no terminator exists (the number ends the string),
    // take the whole remainder — the original `?` here wrongly returned
    // `None` for a valid trailing number.
    let end = trimmed
        .find(|c: char| {
            !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+'
        })
        .unwrap_or(trimmed.len());
    trimmed[..end].parse::<f64>().ok()
}
193
/// Decision from the EXPLAIN pre-check.
#[derive(Debug)]
pub enum ExplainDecision {
    /// Query is allowed to proceed.
    Allow,
    /// Query is rejected with an explanation.
    Reject {
        /// PostgreSQL's estimated total cost for the query.
        total_cost: f64,
        /// PostgreSQL's estimated row count.
        plan_rows: u64,
        /// Configured maximum cost threshold.
        max_cost: f64,
        /// Configured maximum row threshold.
        max_rows: u64,
    },
    /// EXPLAIN was skipped (mode=Off or below depth threshold).
    Skipped,
}

impl ExplainDecision {
    /// Returns true if the query should be rejected.
    pub fn is_rejected(&self) -> bool {
        matches!(self, Self::Reject { .. })
    }

    /// Human-readable rejection message for API responses.
    ///
    /// Returns `None` for `Allow` and `Skipped` decisions.
    pub fn rejection_message(&self) -> Option<String> {
        let Self::Reject { total_cost, plan_rows, max_cost, max_rows } = self else {
            return None;
        };
        Some(format!(
            "Query rejected: estimated cost {:.0} exceeds limit {:.0}, \
             or estimated rows {} exceeds limit {}. \
             Try narrowing your filters, reducing ?expand depth, or using pagination.",
            total_cost, max_cost, plan_rows, max_rows
        ))
    }

    /// Machine-readable rejection detail for structured API error responses.
    ///
    /// Returns `None` for `Allow` and `Skipped` decisions.
    /// Client SDKs can use this to programmatically react to cost rejections.
    pub fn rejection_detail(&self) -> Option<ExplainRejectionDetail> {
        let Self::Reject { total_cost, plan_rows, max_cost, max_rows } = self else {
            return None;
        };
        Some(ExplainRejectionDetail {
            estimated_cost: *total_cost,
            cost_limit: *max_cost,
            estimated_rows: *plan_rows,
            row_limit: *max_rows,
            suggestions: vec![
                "Add WHERE clauses to narrow the result set".to_string(),
                "Reduce ?expand depth (deep JOINs multiply cost)".to_string(),
                "Use ?limit and ?offset for pagination".to_string(),
                "Add indexes on frequently filtered columns".to_string(),
            ],
        })
    }
}

/// Structured rejection detail for EXPLAIN cost guard violations.
#[derive(Debug, Clone)]
pub struct ExplainRejectionDetail {
    /// PostgreSQL's estimated total cost for the query.
    pub estimated_cost: f64,
    /// Configured maximum cost threshold.
    pub cost_limit: f64,
    /// PostgreSQL's estimated row count.
    pub estimated_rows: u64,
    /// Configured maximum row threshold.
    pub row_limit: u64,
    /// Actionable suggestions to bring the query under limits.
    pub suggestions: Vec<String>,
}
274
275/// Check an EXPLAIN estimate against configured thresholds.
276pub fn check_estimate(estimate: &ExplainEstimate, config: &ExplainConfig) -> ExplainDecision {
277    if estimate.total_cost > config.max_total_cost || estimate.plan_rows > config.max_plan_rows {
278        ExplainDecision::Reject {
279            total_cost: estimate.total_cost,
280            plan_rows: estimate.plan_rows,
281            max_cost: config.max_total_cost,
282            max_rows: config.max_plan_rows,
283        }
284    } else {
285        ExplainDecision::Allow
286    }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_explain_json() {
        let raw = r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 1234.56, "Plan Rows": 5000, "Plan Width": 100}}]"#;
        let parsed = parse_explain_json(raw).unwrap();
        assert!((parsed.total_cost - 1234.56).abs() < 0.01);
        assert_eq!(parsed.plan_rows, 5000);
    }

    #[test]
    fn test_parse_explain_json_nested_join() {
        let raw = r#"[{"Plan": {"Node Type": "Hash Join", "Total Cost": 250000.0, "Plan Rows": 2000000, "Plan Width": 200}}]"#;
        let parsed = parse_explain_json(raw).unwrap();
        assert!((parsed.total_cost - 250000.0).abs() < 0.01);
        assert_eq!(parsed.plan_rows, 2_000_000);
    }

    #[test]
    fn test_parse_explain_json_invalid() {
        // None of these contain the expected keys, so parsing must fail.
        for bad in ["not json", "{}", "[]"] {
            assert!(parse_explain_json(bad).is_none());
        }
    }

    #[test]
    fn test_check_estimate_allow() {
        let cheap = ExplainEstimate { total_cost: 100.0, plan_rows: 500 };
        assert!(!check_estimate(&cheap, &ExplainConfig::default()).is_rejected());
    }

    #[test]
    fn test_check_estimate_reject_cost() {
        // Cost over threshold, rows fine — cost alone must trigger rejection.
        let pricey = ExplainEstimate { total_cost: 200_000.0, plan_rows: 500 };
        let verdict = check_estimate(&pricey, &ExplainConfig::default());
        assert!(verdict.is_rejected());
        assert!(verdict.rejection_message().unwrap().contains("200000"));
    }

    #[test]
    fn test_check_estimate_reject_rows() {
        // Rows over threshold, cost fine — rows alone must trigger rejection.
        let wide = ExplainEstimate { total_cost: 50.0, plan_rows: 5_000_000 };
        assert!(check_estimate(&wide, &ExplainConfig::default()).is_rejected());
    }

    #[test]
    fn test_cache_basic() {
        let cache = ExplainCache::new(Duration::from_secs(60));
        assert!(cache.is_empty());

        cache.insert(42, ExplainEstimate { total_cost: 100.0, plan_rows: 50 });
        assert_eq!(cache.len(), 1);

        let hit = cache.get(42, None).unwrap();
        assert!((hit.total_cost - 100.0).abs() < 0.01);
        assert_eq!(hit.plan_rows, 50);

        // Unknown key is a miss.
        assert!(cache.get(99, None).is_none());
    }

    #[test]
    fn test_cache_expiry() {
        let cache = ExplainCache::new(Duration::from_millis(1));
        cache.insert(1, ExplainEstimate { total_cost: 100.0, plan_rows: 50 });

        // Sleep past the TTL; the entry must have expired.
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get(1, None).is_none());
    }

    #[test]
    fn test_cache_drift_invalidation() {
        let cache = ExplainCache::new(Duration::from_secs(60));

        // ── Small dataset: relative drift alone should NOT invalidate ──
        cache.insert(1, ExplainEstimate { total_cost: 50.0, plan_rows: 1000 });

        // No reltuples — pure TTL, should hit.
        assert!(cache.get(1, None).is_some());
        // Same estimate — no drift, should hit.
        assert!(cache.get(1, Some(1000)).is_some());
        // 60% relative drift but only 600 absolute — below 10k floor, should STILL hit.
        assert!(cache.get(1, Some(1600)).is_some(), "small table should not thrash");
        // 60% shrinkage but only 600 absolute — should STILL hit.
        assert!(cache.get(1, Some(400)).is_some(), "small shrinkage should not thrash");

        // ── Large dataset: BOTH relative AND absolute thresholds exceeded ──
        cache.insert(3, ExplainEstimate { total_cost: 500.0, plan_rows: 50_000 });

        // 70% drift + 35k absolute (both above threshold) — should miss.
        assert!(cache.get(3, Some(85_000)).is_none(), "large drift should invalidate");
        // 40% drift + 20k absolute (relative below 50%) — should STILL hit.
        assert!(cache.get(3, Some(70_000)).is_some(), "moderate drift should not invalidate");
        // 60% shrinkage + 30k absolute (both above threshold) — should miss.
        assert!(cache.get(3, Some(20_000)).is_none(), "large shrinkage should invalidate");

        // Edge: plan_rows = 0 in cache — the drift check is skipped entirely.
        cache.insert(2, ExplainEstimate { total_cost: 10.0, plan_rows: 0 });
        assert!(cache.get(2, Some(999_999)).is_some());
    }

    #[test]
    fn test_explain_mode_default() {
        let defaults = ExplainConfig::default();
        assert_eq!(defaults.mode, ExplainMode::Precheck);
        assert_eq!(defaults.depth_threshold, 3);
        assert!((defaults.max_total_cost - 100_000.0).abs() < 0.01);
        assert_eq!(defaults.max_plan_rows, 1_000_000);
    }
}
411}