1use std::collections::HashMap;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
/// Settings for the EXPLAIN-based query cost check.
#[derive(Debug, Clone)]
pub struct ExplainConfig {
    /// Which [`ExplainMode`] the check runs in.
    pub mode: ExplainMode,

    /// Depth threshold (default 3). NOTE(review): presumably the `?expand`
    /// depth at which EXPLAIN is consulted — not used in this file; confirm
    /// against the caller.
    pub depth_threshold: usize,

    /// A plan whose estimated total cost is strictly greater than this value
    /// is rejected by `check_estimate`.
    pub max_total_cost: f64,

    /// A plan whose estimated row count is strictly greater than this value
    /// is rejected by `check_estimate`.
    pub max_plan_rows: u64,

    /// How long estimates may be cached. NOTE(review): presumably handed to
    /// `ExplainCache::new` by the caller — confirm.
    pub cache_ttl: Duration,
}

impl Default for ExplainConfig {
    /// `Precheck` mode, depth threshold 3, cost limit 100,000,
    /// row limit 1,000,000 and a five-minute cache TTL.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(5 * 60);
        let cost_ceiling = 1e5;
        let row_ceiling: u64 = 1_000_000;
        ExplainConfig {
            cache_ttl: five_minutes,
            max_plan_rows: row_ceiling,
            max_total_cost: cost_ceiling,
            depth_threshold: 3,
            mode: ExplainMode::Precheck,
        }
    }
}

/// Operating mode of the EXPLAIN check.
///
/// NOTE(review): the behaviour of each mode is implemented by callers not
/// visible here; the descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExplainMode {
    /// Presumably: no EXPLAIN check is performed.
    Off,
    /// Presumably: EXPLAIN is run ahead of execution as a pre-check.
    Precheck,
    /// Presumably: the limits are enforced strictly.
    Enforce,
}
66
/// Planner estimate for a query plan.
#[derive(Debug, Clone)]
pub struct ExplainEstimate {
    /// Estimated total plan cost (`parse_explain_json` reads it from `"Total Cost"`).
    pub total_cost: f64,
    /// Estimated row count (`parse_explain_json` reads it from `"Plan Rows"`).
    pub plan_rows: u64,
}

/// A cached estimate together with the instant it was inserted.
///
/// Drift checks use `estimate.plan_rows` as the cached row count.
struct CachedEstimate {
    estimate: ExplainEstimate,
    cached_at: Instant,
}

/// Mutex-protected cache of [`ExplainEstimate`]s keyed by a query-shape hash.
///
/// Entries expire after a TTL, can be invalidated by row-count drift, and the
/// map is capped at `max_entries` (new keys are dropped once full; there is no
/// eviction of live entries).
pub struct ExplainCache {
    /// Cached entries keyed by shape hash.
    entries: Mutex<HashMap<u64, CachedEstimate>>,
    /// Age after which `get` no longer returns an entry.
    ttl: Duration,
    /// Upper bound on the number of stored entries.
    max_entries: usize,
}

impl ExplainCache {
    /// Capacity used by [`ExplainCache::new`].
    const DEFAULT_MAX_ENTRIES: usize = 10_000;
    /// Relative row-count change that must be exceeded for drift to invalidate.
    const DRIFT_RATIO_THRESHOLD: f64 = 0.5;
    /// Absolute row-count change that must also be exceeded, so that small
    /// tables with large *relative* swings do not thrash the cache.
    const DRIFT_ABS_THRESHOLD: u64 = 10_000;

    /// Creates an empty cache whose entries expire after `ttl`, holding at
    /// most 10,000 entries.
    pub fn new(ttl: Duration) -> Self {
        Self::with_max_entries(ttl, Self::DEFAULT_MAX_ENTRIES)
    }

    /// Creates an empty cache whose entries expire after `ttl`, holding at
    /// most `max_entries` entries.
    pub fn with_max_entries(ttl: Duration, max_entries: usize) -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            ttl,
            max_entries,
        }
    }

    /// Locks the entry map, recovering the guard if a previous holder panicked.
    ///
    /// The map only stores plain values, so a poisoned lock still guards a
    /// usable map; recovering keeps a single panic from silently disabling the
    /// cache for the rest of the process.
    fn lock_entries(&self) -> std::sync::MutexGuard<'_, HashMap<u64, CachedEstimate>> {
        self.entries
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Returns a clone of the cached estimate for `shape_hash` if it is
    /// present and still valid.
    ///
    /// An entry is treated as a miss when it is at least `ttl` old, or when
    /// `current_reltuples` is given and differs from the cached row count by
    /// more than 50% *and* by more than 10,000 rows. Entries cached with a row
    /// count of 0 are never invalidated by drift (there is no base to compare).
    pub fn get(&self, shape_hash: u64, current_reltuples: Option<u64>) -> Option<ExplainEstimate> {
        let entries = self.lock_entries();
        let entry = entries.get(&shape_hash)?;
        if entry.cached_at.elapsed() >= self.ttl {
            return None;
        }
        if let Some(current) = current_reltuples {
            if Self::has_drifted(entry.estimate.plan_rows, current) {
                return None;
            }
        }
        Some(entry.estimate.clone())
    }

    /// Whether the live row count `current` is far enough from the cached row
    /// count `cached` that the cached estimate should not be trusted.
    fn has_drifted(cached: u64, current: u64) -> bool {
        if cached == 0 {
            return false;
        }
        // `abs_diff` is exact for all u64 values; casting to i64 and
        // subtracting could wrap or overflow for large counts.
        let abs_delta = current.abs_diff(cached);
        let ratio = abs_delta as f64 / cached as f64;
        ratio > Self::DRIFT_RATIO_THRESHOLD && abs_delta > Self::DRIFT_ABS_THRESHOLD
    }

    /// Stores `estimate` under `shape_hash`, replacing any existing entry.
    ///
    /// Once the map is at least half full, expired entries are purged first.
    /// If it is still at `max_entries`, a *new* key is dropped (best-effort
    /// cache), but an existing key is still refreshed because replacing it
    /// does not grow the map.
    pub fn insert(&self, shape_hash: u64, estimate: ExplainEstimate) {
        let mut entries = self.lock_entries();
        if entries.len() >= self.max_entries / 2 {
            let ttl = self.ttl;
            entries.retain(|_, cached| cached.cached_at.elapsed() < ttl);
        }
        if entries.len() >= self.max_entries && !entries.contains_key(&shape_hash) {
            return;
        }
        entries.insert(
            shape_hash,
            CachedEstimate {
                estimate,
                cached_at: Instant::now(),
            },
        );
    }

    /// Number of stored entries, including expired ones not yet purged.
    pub fn len(&self) -> usize {
        self.lock_entries().len()
    }

    /// Whether no entries are stored (expired-but-unpurged entries count).
    pub fn is_empty(&self) -> bool {
        self.lock_entries().is_empty()
    }
}
161
162pub fn parse_explain_json(json_str: &str) -> Option<ExplainEstimate> {
170 let total_cost = extract_json_number(json_str, "Total Cost")?;
171 let plan_rows = extract_json_number(json_str, "Plan Rows")? as u64;
172
173 Some(ExplainEstimate {
174 total_cost,
175 plan_rows,
176 })
177}
178
/// Returns the numeric value stored under `key` in a JSON document, found by
/// plain string scanning rather than a full JSON parse.
///
/// The first occurrence of `"key"` that is followed (after optional
/// whitespace) by `:` is used. Its value is the run of characters that can
/// appear in a JSON number (`0-9 . - + e E`), ending at the first other
/// character or at end of input. Returns `None` if no such key exists or the
/// value does not parse as `f64`.
///
/// The scanner does not understand nesting: for nested documents it returns
/// the first matching key in textual order.
fn extract_json_number(json: &str, key: &str) -> Option<f64> {
    let quoted_key = format!("\"{}\"", key);
    // Skip occurrences that are not used as a key (e.g. an identical string
    // *value*), which are not followed by a colon.
    let value = json
        .match_indices(quoted_key.as_str())
        .find_map(move |(start, matched)| {
            json[start + matched.len()..].trim_start().strip_prefix(':')
        })?
        .trim_start();

    let is_number_char =
        |c: char| c.is_ascii_digit() || matches!(c, '.' | '-' | '+' | 'e' | 'E');
    // A number that is the last token of the input has no terminator; take
    // the rest of the string instead of failing.
    let end = value
        .find(|c: char| !is_number_char(c))
        .unwrap_or(value.len());
    value[..end].parse::<f64>().ok()
}
193
/// Outcome of checking a plan estimate against the configured limits.
#[derive(Debug)]
pub enum ExplainDecision {
    /// The estimate is within the limits.
    Allow,
    /// The estimate was rejected; `check_estimate` produces this when the cost
    /// or the row count is strictly above its limit.
    Reject {
        /// Estimated total plan cost.
        total_cost: f64,
        /// Estimated row count.
        plan_rows: u64,
        /// Configured cost limit.
        max_cost: f64,
        /// Configured row limit.
        max_rows: u64,
    },
    /// NOTE(review): not constructed in this file; presumably means the check
    /// was not run (e.g. mode `Off`) — confirm against the caller.
    Skipped,
}

impl ExplainDecision {
    /// Returns `true` only for [`ExplainDecision::Reject`].
    pub fn is_rejected(&self) -> bool {
        matches!(self, ExplainDecision::Reject { .. })
    }

    /// Human-readable rejection text, or `None` for non-`Reject` decisions.
    ///
    /// Only the limit(s) that were actually exceeded are named. If neither
    /// comparison holds (a `Reject` built by hand), both values are reported
    /// with the generic "..., or ..." wording.
    pub fn rejection_message(&self) -> Option<String> {
        let (total_cost, plan_rows, max_cost, max_rows) = match *self {
            ExplainDecision::Reject {
                total_cost,
                plan_rows,
                max_cost,
                max_rows,
            } => (total_cost, plan_rows, max_cost, max_rows),
            ExplainDecision::Allow | ExplainDecision::Skipped => return None,
        };

        let cost_clause = format!(
            "estimated cost {:.0} exceeds limit {:.0}",
            total_cost, max_cost
        );
        let rows_clause = format!("estimated rows {} exceeds limit {}", plan_rows, max_rows);
        let reason = match (total_cost > max_cost, plan_rows > max_rows) {
            (true, true) => format!("{} and {}", cost_clause, rows_clause),
            (true, false) => cost_clause,
            (false, true) => rows_clause,
            (false, false) => format!("{}, or {}", cost_clause, rows_clause),
        };

        Some(format!(
            "Query rejected: {}. \
             Try narrowing your filters, reducing ?expand depth, or using pagination.",
            reason
        ))
    }

    /// Structured rejection details with fixed remediation suggestions, or
    /// `None` for non-`Reject` decisions.
    pub fn rejection_detail(&self) -> Option<ExplainRejectionDetail> {
        match self {
            ExplainDecision::Reject {
                total_cost,
                plan_rows,
                max_cost,
                max_rows,
            } => Some(ExplainRejectionDetail {
                estimated_cost: *total_cost,
                cost_limit: *max_cost,
                estimated_rows: *plan_rows,
                row_limit: *max_rows,
                suggestions: vec![
                    "Add WHERE clauses to narrow the result set".to_string(),
                    "Reduce ?expand depth (deep JOINs multiply cost)".to_string(),
                    "Use ?limit and ?offset for pagination".to_string(),
                    "Add indexes on frequently filtered columns".to_string(),
                ],
            }),
            ExplainDecision::Allow | ExplainDecision::Skipped => None,
        }
    }
}

/// Structured form of a rejection, as returned by
/// [`ExplainDecision::rejection_detail`].
#[derive(Debug, Clone)]
pub struct ExplainRejectionDetail {
    /// Estimated total plan cost.
    pub estimated_cost: f64,
    /// Configured cost limit.
    pub cost_limit: f64,
    /// Estimated row count.
    pub estimated_rows: u64,
    /// Configured row limit.
    pub row_limit: u64,
    /// Fixed list of remediation hints.
    pub suggestions: Vec<String>,
}
274
275pub fn check_estimate(estimate: &ExplainEstimate, config: &ExplainConfig) -> ExplainDecision {
277 if estimate.total_cost > config.max_total_cost || estimate.plan_rows > config.max_plan_rows {
278 ExplainDecision::Reject {
279 total_cost: estimate.total_cost,
280 plan_rows: estimate.plan_rows,
281 max_cost: config.max_total_cost,
282 max_rows: config.max_plan_rows,
283 }
284 } else {
285 ExplainDecision::Allow
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    // A single-node plan: both estimates are read from the root "Plan" object.
    #[test]
    fn test_parse_explain_json() {
        let json = r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 1234.56, "Plan Rows": 5000, "Plan Width": 100}}]"#;
        let est = parse_explain_json(json).unwrap();
        assert!((est.total_cost - 1234.56).abs() < 0.01);
        assert_eq!(est.plan_rows, 5000);
    }

    // NOTE(review): despite the name, this JSON has no child "Plans" array;
    // it only exercises a join node type with larger values.
    #[test]
    fn test_parse_explain_json_nested_join() {
        let json = r#"[{"Plan": {"Node Type": "Hash Join", "Total Cost": 250000.0, "Plan Rows": 2000000, "Plan Width": 200}}]"#;
        let est = parse_explain_json(json).unwrap();
        assert!((est.total_cost - 250000.0).abs() < 0.01);
        assert_eq!(est.plan_rows, 2_000_000);
    }

    // Inputs without the "Total Cost"/"Plan Rows" keys yield None.
    #[test]
    fn test_parse_explain_json_invalid() {
        assert!(parse_explain_json("not json").is_none());
        assert!(parse_explain_json("{}").is_none());
        assert!(parse_explain_json("[]").is_none());
    }

    // Default limits are cost 100_000 and rows 1_000_000; this estimate is under both.
    #[test]
    fn test_check_estimate_allow() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 100.0, plan_rows: 500 };
        let decision = check_estimate(&est, &config);
        assert!(!decision.is_rejected());
    }

    // Cost over the limit rejects, and the message reports the cost formatted with {:.0}.
    #[test]
    fn test_check_estimate_reject_cost() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 200_000.0, plan_rows: 500 };
        let decision = check_estimate(&est, &config);
        assert!(decision.is_rejected());
        assert!(decision.rejection_message().unwrap().contains("200000"));
    }

    // Row count over the limit rejects even when cost is tiny.
    #[test]
    fn test_check_estimate_reject_rows() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 50.0, plan_rows: 5_000_000 };
        let decision = check_estimate(&est, &config);
        assert!(decision.is_rejected());
    }

    // Insert/get round trip; an unknown hash misses.
    #[test]
    fn test_cache_basic() {
        let cache = ExplainCache::new(Duration::from_secs(60));
        assert!(cache.is_empty());

        cache.insert(42, ExplainEstimate { total_cost: 100.0, plan_rows: 50 });
        assert_eq!(cache.len(), 1);

        // `None` skips the drift check entirely.
        let cached = cache.get(42, None).unwrap();
        assert!((cached.total_cost - 100.0).abs() < 0.01);
        assert_eq!(cached.plan_rows, 50);

        assert!(cache.get(99, None).is_none());
    }

    // With a 1 ms TTL, sleeping 5 ms makes the entry stale, so `get` misses.
    #[test]
    fn test_cache_expiry() {
        let cache = ExplainCache::new(Duration::from_millis(1));
        cache.insert(1, ExplainEstimate { total_cost: 100.0, plan_rows: 50 });

        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get(1, None).is_none());
    }

    // Drift invalidates only when the change is > 50% AND > 10_000 rows.
    #[test]
    fn test_cache_drift_invalidation() {
        let cache = ExplainCache::new(Duration::from_secs(60));

        cache.insert(1, ExplainEstimate { total_cost: 50.0, plan_rows: 1000 });

        // No live row count supplied: drift is not checked.
        assert!(cache.get(1, None).is_some());

        // Unchanged row count.
        assert!(cache.get(1, Some(1000)).is_some());

        // 60% growth but only 600 rows: below the 10_000 absolute floor.
        assert!(cache.get(1, Some(1600)).is_some(), "small table should not thrash");

        // 60% shrinkage but only 600 rows: also below the floor.
        assert!(cache.get(1, Some(400)).is_some(), "small shrinkage should not thrash");

        cache.insert(3, ExplainEstimate { total_cost: 500.0, plan_rows: 50_000 });

        // 70% and 35_000 rows: both thresholds exceeded.
        assert!(cache.get(3, Some(85_000)).is_none(), "large drift should invalidate");

        // 40%: under the 50% ratio threshold.
        assert!(cache.get(3, Some(70_000)).is_some(), "moderate drift should not invalidate");

        // 60% and 30_000 rows: both thresholds exceeded.
        assert!(cache.get(3, Some(20_000)).is_none(), "large shrinkage should invalidate");

        // A cached row count of 0 disables the drift check.
        cache.insert(2, ExplainEstimate { total_cost: 10.0, plan_rows: 0 });
        assert!(cache.get(2, Some(999_999)).is_some());
    }

    // Pins the values returned by `ExplainConfig::default()`.
    #[test]
    fn test_explain_mode_default() {
        let config = ExplainConfig::default();
        assert_eq!(config.mode, ExplainMode::Precheck);
        assert_eq!(config.depth_threshold, 3);
        assert!((config.max_total_cost - 100_000.0).abs() < 0.01);
        assert_eq!(config.max_plan_rows, 1_000_000);
    }
}