1use std::collections::HashMap;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
19#[derive(Debug, Clone)]
21pub struct ExplainConfig {
22 pub mode: ExplainMode,
24
25 pub depth_threshold: usize,
28
29 pub max_total_cost: f64,
32
33 pub max_plan_rows: u64,
36
37 pub cache_ttl: Duration,
40}
41
42impl Default for ExplainConfig {
43 fn default() -> Self {
44 Self {
45 mode: ExplainMode::Precheck,
46 depth_threshold: 3,
47 max_total_cost: 100_000.0,
48 max_plan_rows: 1_000_000,
49 cache_ttl: Duration::from_secs(300),
50 }
51 }
52}
53
/// Operating mode stored in `ExplainConfig::mode`.
///
/// `Precheck` is the mode chosen by `ExplainConfig::default()`.
// NOTE(review): nothing in this file branches on the variant, so the exact
// behaviour of each mode is defined by callers — confirm before relying on
// the one-line descriptions below.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExplainMode {
    /// Presumably: the EXPLAIN check is disabled.
    Off,
    /// Presumably: the EXPLAIN check runs ahead of the query (default).
    Precheck,
    /// Presumably: the EXPLAIN check is strictly enforced.
    Enforce,
}
66
/// Planner estimate for a single query, as produced by `parse_explain_json`.
#[derive(Debug, Clone)]
pub struct ExplainEstimate {
    /// Value of the `"Total Cost"` key in the EXPLAIN output.
    pub total_cost: f64,
    /// Value of the `"Plan Rows"` key in the EXPLAIN output.
    pub plan_rows: u64,
}
75
76struct CachedEstimate {
78 estimate: ExplainEstimate,
79 cached_at: Instant,
80 plan_rows: u64,
82}
83
84pub struct ExplainCache {
86 entries: Mutex<HashMap<u64, CachedEstimate>>,
87 ttl: Duration,
88 max_entries: usize,
90}
91
92impl ExplainCache {
93 pub fn new(ttl: Duration) -> Self {
95 Self {
96 entries: Mutex::new(HashMap::new()),
97 ttl,
98 max_entries: 10_000,
99 }
100 }
101
102 pub fn get(&self, shape_hash: u64, current_reltuples: Option<u64>) -> Option<ExplainEstimate> {
109 let entries = self.entries.lock().ok()?;
110 let entry = entries.get(&shape_hash)?;
111 if entry.cached_at.elapsed() < self.ttl {
112 if let Some(current) = current_reltuples
116 && entry.plan_rows > 0
117 {
118 let cached = entry.plan_rows as f64;
119 let drift = ((current as f64) - cached).abs() / cached;
120 let abs_delta = (current as i64 - entry.plan_rows as i64).unsigned_abs();
121 if drift > 0.5 && abs_delta > 10_000 {
122 return None; }
124 }
125 Some(entry.estimate.clone())
126 } else {
127 None
128 }
129 }
130
131 pub fn insert(&self, shape_hash: u64, estimate: ExplainEstimate) {
133 if let Ok(mut entries) = self.entries.lock() {
134 if entries.len() >= self.max_entries / 2 {
136 let ttl = self.ttl;
137 entries.retain(|_, v| v.cached_at.elapsed() < ttl);
138 }
139 if entries.len() >= self.max_entries {
141 return;
142 }
143 entries.insert(
144 shape_hash,
145 CachedEstimate {
146 plan_rows: estimate.plan_rows,
147 estimate,
148 cached_at: Instant::now(),
149 },
150 );
151 }
152 }
153
154 pub fn len(&self) -> usize {
156 self.entries.lock().map(|e| e.len()).unwrap_or(0)
157 }
158
159 pub fn is_empty(&self) -> bool {
161 self.len() == 0
162 }
163}
164
165pub fn parse_explain_json(json_str: &str) -> Option<ExplainEstimate> {
173 let total_cost = extract_json_number(json_str, "Total Cost")?;
174 let plan_rows = extract_json_number(json_str, "Plan Rows")? as u64;
175
176 Some(ExplainEstimate {
177 total_cost,
178 plan_rows,
179 })
180}
181
/// Finds the first occurrence of `"key":` in `json` and parses the number
/// that follows it.
///
/// This is a lightweight textual scan, not a JSON parser: the key must be
/// written exactly as `"key":` (no whitespace before the colon), and the
/// value is the longest run of characters drawn from digits, `.`, `-`, `+`,
/// `e` and `E` after any leading whitespace. The run may extend to the end
/// of the input.
///
/// Returns `None` when the key is absent or the run does not parse as `f64`
/// (for example when the value is a string, `null`, or empty).
fn extract_json_number(json: &str, key: &str) -> Option<f64> {
    let pattern = format!("\"{}\":", key);
    let (_, after_key) = json.split_once(pattern.as_str())?;
    let value = after_key.trim_start();

    // A number that is the last thing in the input has no terminating
    // character; treat end-of-input as the terminator instead of failing.
    let end = value
        .find(|c: char| !matches!(c, '0'..='9' | '.' | '-' | '+' | 'e' | 'E'))
        .unwrap_or(value.len());

    value[..end].parse::<f64>().ok()
}
198
/// Outcome of evaluating a query against the configured EXPLAIN limits.
#[derive(Debug)]
pub enum ExplainDecision {
    /// The estimate is within both limits.
    Allow,
    /// The estimate exceeded at least one limit; carries the numbers needed
    /// to explain the rejection.
    Reject {
        /// Estimated total cost of the rejected query.
        total_cost: f64,
        /// Estimated row count of the rejected query.
        plan_rows: u64,
        /// Cost limit that was in force.
        max_cost: f64,
        /// Row limit that was in force.
        max_rows: u64,
    },
    /// No decision was taken.
    // NOTE(review): nothing in this file constructs this variant; presumably
    // callers use it when the check is bypassed — confirm.
    Skipped,
}
218
219impl ExplainDecision {
220 pub fn is_rejected(&self) -> bool {
222 matches!(self, ExplainDecision::Reject { .. })
223 }
224
225 pub fn rejection_message(&self) -> Option<String> {
227 match self {
228 ExplainDecision::Reject {
229 total_cost,
230 plan_rows,
231 max_cost,
232 max_rows,
233 } => Some(format!(
234 "Query rejected: estimated cost {:.0} exceeds limit {:.0}, \
235 or estimated rows {} exceeds limit {}. \
236 Try narrowing your filters, reducing ?expand depth, or using pagination.",
237 total_cost, max_cost, plan_rows, max_rows
238 )),
239 _ => None,
240 }
241 }
242
243 pub fn rejection_detail(&self) -> Option<ExplainRejectionDetail> {
248 match self {
249 ExplainDecision::Reject {
250 total_cost,
251 plan_rows,
252 max_cost,
253 max_rows,
254 } => Some(ExplainRejectionDetail {
255 estimated_cost: *total_cost,
256 cost_limit: *max_cost,
257 estimated_rows: *plan_rows,
258 row_limit: *max_rows,
259 suggestions: vec![
260 "Add WHERE clauses to narrow the result set".to_string(),
261 "Reduce ?expand depth (deep JOINs multiply cost)".to_string(),
262 "Use ?limit and ?offset for pagination".to_string(),
263 "Add indexes on frequently filtered columns".to_string(),
264 ],
265 }),
266 _ => None,
267 }
268 }
269}
270
/// Structured description of a rejected query, built by
/// `ExplainDecision::rejection_detail`.
#[derive(Debug, Clone)]
pub struct ExplainRejectionDetail {
    /// Estimated total cost of the rejected query.
    pub estimated_cost: f64,
    /// Cost limit the estimate was compared against.
    pub cost_limit: f64,
    /// Estimated row count of the rejected query.
    pub estimated_rows: u64,
    /// Row limit the estimate was compared against.
    pub row_limit: u64,
    /// Remediation hints for the caller, in display order.
    pub suggestions: Vec<String>,
}
285
286pub fn check_estimate(estimate: &ExplainEstimate, config: &ExplainConfig) -> ExplainDecision {
288 if estimate.total_cost > config.max_total_cost || estimate.plan_rows > config.max_plan_rows {
289 ExplainDecision::Reject {
290 total_cost: estimate.total_cost,
291 plan_rows: estimate.plan_rows,
292 max_cost: config.max_total_cost,
293 max_rows: config.max_plan_rows,
294 }
295 } else {
296 ExplainDecision::Allow
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an estimate.
    fn estimate(total_cost: f64, plan_rows: u64) -> ExplainEstimate {
        ExplainEstimate {
            total_cost,
            plan_rows,
        }
    }

    /// `true` when `actual` is within 0.01 of `expected`.
    fn close_to(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < 0.01
    }

    // ---- parse_explain_json -------------------------------------------------

    #[test]
    fn test_parse_explain_json() {
        let json = r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 1234.56, "Plan Rows": 5000, "Plan Width": 100}}]"#;
        let parsed = parse_explain_json(json).unwrap();
        assert!(close_to(parsed.total_cost, 1234.56));
        assert_eq!(parsed.plan_rows, 5000);
    }

    #[test]
    fn test_parse_explain_json_nested_join() {
        let json = r#"[{"Plan": {"Node Type": "Hash Join", "Total Cost": 250000.0, "Plan Rows": 2000000, "Plan Width": 200}}]"#;
        let parsed = parse_explain_json(json).unwrap();
        assert!(close_to(parsed.total_cost, 250000.0));
        assert_eq!(parsed.plan_rows, 2_000_000);
    }

    #[test]
    fn test_parse_explain_json_invalid() {
        for input in ["not json", "{}", "[]"] {
            assert!(parse_explain_json(input).is_none());
        }
    }

    // ---- check_estimate -----------------------------------------------------

    #[test]
    fn test_check_estimate_allow() {
        let outcome = check_estimate(&estimate(100.0, 500), &ExplainConfig::default());
        assert!(!outcome.is_rejected());
    }

    #[test]
    fn test_check_estimate_reject_cost() {
        let outcome = check_estimate(&estimate(200_000.0, 500), &ExplainConfig::default());
        assert!(outcome.is_rejected());

        let message = outcome.rejection_message().unwrap();
        assert!(message.contains("200000"));
    }

    #[test]
    fn test_check_estimate_reject_rows() {
        let outcome = check_estimate(&estimate(50.0, 5_000_000), &ExplainConfig::default());
        assert!(outcome.is_rejected());
    }

    // ---- ExplainCache -------------------------------------------------------

    #[test]
    fn test_cache_basic() {
        let cache = ExplainCache::new(Duration::from_secs(60));
        assert!(cache.is_empty());

        cache.insert(42, estimate(100.0, 50));
        assert_eq!(cache.len(), 1);

        let hit = cache.get(42, None).unwrap();
        assert!(close_to(hit.total_cost, 100.0));
        assert_eq!(hit.plan_rows, 50);

        // A key that was never inserted is a miss.
        assert!(cache.get(99, None).is_none());
    }

    #[test]
    fn test_cache_expiry() {
        let cache = ExplainCache::new(Duration::from_millis(1));
        cache.insert(1, estimate(100.0, 50));

        // Wait past the 1 ms TTL so the entry is considered stale.
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get(1, None).is_none());
    }

    #[test]
    fn test_cache_drift_invalidation() {
        let cache = ExplainCache::new(Duration::from_secs(60));
        let hit = |key: u64, reltuples: Option<u64>| cache.get(key, reltuples).is_some();

        // Small table: relative drift can exceed 50 % but the absolute change
        // stays below 10 000 rows, so the entry keeps being served.
        cache.insert(1, estimate(50.0, 1000));
        assert!(hit(1, None));
        assert!(hit(1, Some(1000)));
        assert!(hit(1, Some(1600)), "small table should not thrash");
        assert!(hit(1, Some(400)), "small shrinkage should not thrash");

        // Larger table: both thresholds must be crossed to invalidate.
        cache.insert(3, estimate(500.0, 50_000));
        assert!(!hit(3, Some(85_000)), "large drift should invalidate");
        assert!(hit(3, Some(70_000)), "moderate drift should not invalidate");
        assert!(!hit(3, Some(20_000)), "large shrinkage should invalidate");

        // A zero-row baseline skips the drift check altogether.
        cache.insert(2, estimate(10.0, 0));
        assert!(hit(2, Some(999_999)));
    }

    // ---- ExplainConfig ------------------------------------------------------

    #[test]
    fn test_explain_mode_default() {
        let ExplainConfig {
            mode,
            depth_threshold,
            max_total_cost,
            max_plan_rows,
            ..
        } = ExplainConfig::default();

        assert_eq!(mode, ExplainMode::Precheck);
        assert_eq!(depth_threshold, 3);
        assert!(close_to(max_total_cost, 100_000.0));
        assert_eq!(max_plan_rows, 1_000_000);
    }
}