1use std::collections::HashMap;
16use std::sync::Mutex;
17use std::time::{Duration, Instant};
18
/// Settings for the EXPLAIN-based query cost guard.
///
/// `max_total_cost` and `max_plan_rows` are the limits applied by `check_estimate`.
#[derive(Debug, Clone)]
pub struct ExplainConfig {
    /// Operating mode of the guard (see [`ExplainMode`]).
    pub mode: ExplainMode,

    /// Depth threshold (default 3).
    /// NOTE(review): not read in this file — presumably the `?expand` depth at which an
    /// EXPLAIN precheck kicks in; confirm against the caller.
    pub depth_threshold: usize,

    /// Largest acceptable planner "Total Cost"; estimates strictly above it are rejected.
    pub max_total_cost: f64,

    /// Largest acceptable planner "Plan Rows"; estimates strictly above it are rejected.
    pub max_plan_rows: u64,

    /// Lifetime of cached estimates (default five minutes).
    /// NOTE(review): presumably handed to `ExplainCache::new`; confirm against the caller.
    pub cache_ttl: Duration,
}

impl Default for ExplainConfig {
    /// Precheck mode, depth threshold 3, cost limit 100k, row limit 1M, five-minute TTL.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);

        ExplainConfig {
            max_plan_rows: 1_000_000,
            max_total_cost: 1e5,
            depth_threshold: 3,
            cache_ttl: FIVE_MINUTES,
            mode: ExplainMode::Precheck,
        }
    }
}

/// How the EXPLAIN guard operates.
///
/// NOTE(review): the behavioural difference between the variants is not visible in this
/// file; the per-variant notes below are inferred from the names — confirm with the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExplainMode {
    /// Presumably disables EXPLAIN checks entirely.
    Off,
    /// The default mode chosen by `ExplainConfig::default`.
    Precheck,
    /// Presumably applies the limits unconditionally.
    Enforce,
}
66
/// Cost/row estimate for a query plan, as produced by `parse_explain_json`.
#[derive(Debug, Clone, PartialEq)]
pub struct ExplainEstimate {
    /// Planner's estimated total cost (the `"Total Cost"` value).
    pub total_cost: f64,
    /// Planner's estimated row count (the `"Plan Rows"` value).
    pub plan_rows: u64,
}

/// A cached estimate plus the bookkeeping needed for TTL and drift checks.
struct CachedEstimate {
    estimate: ExplainEstimate,
    /// Insertion time; compared against the cache TTL on every lookup.
    cached_at: Instant,
    /// Row count recorded at insert time (copied from `estimate.plan_rows`); the baseline
    /// for drift detection in `ExplainCache::get`.
    plan_rows: u64,
}

/// Mutex-protected, TTL-bounded cache of EXPLAIN estimates keyed by a caller-supplied
/// `u64` shape hash.
pub struct ExplainCache {
    entries: Mutex<HashMap<u64, CachedEstimate>>,
    ttl: Duration,
    /// Capacity cap: once reached, inserts of *new* keys are dropped.
    max_entries: usize,
}

impl ExplainCache {
    /// Relative row-count change above which a cached estimate is considered stale.
    const DRIFT_RATIO: f64 = 0.5;
    /// Absolute row-count change that must also be exceeded before drift invalidates an
    /// entry, so small tables do not thrash the cache.
    const DRIFT_MIN_ROWS: u64 = 10_000;
    /// Capacity used by [`ExplainCache::new`].
    const DEFAULT_MAX_ENTRIES: usize = 10_000;

    /// Creates an empty cache whose entries expire after `ttl`.
    pub fn new(ttl: Duration) -> Self {
        Self {
            entries: Mutex::new(HashMap::new()),
            ttl,
            max_entries: Self::DEFAULT_MAX_ENTRIES,
        }
    }

    /// Looks up the estimate cached for `shape_hash`.
    ///
    /// Returns `None` when the key is absent, the entry is at least `ttl` old, the lock is
    /// poisoned, or `current_reltuples` has drifted too far from the cached row count (more
    /// than 50% relative *and* more than 10_000 rows absolute). A cached row count of 0
    /// disables the drift check. Stale entries are not removed here; they are replaced by
    /// the next `insert` for the same key or purged by `insert`'s expiry sweep.
    pub fn get(&self, shape_hash: u64, current_reltuples: Option<u64>) -> Option<ExplainEstimate> {
        let entries = self.entries.lock().ok()?;
        let entry = entries.get(&shape_hash)?;
        if entry.cached_at.elapsed() >= self.ttl {
            return None;
        }
        let drifted = current_reltuples
            .is_some_and(|current| Self::has_drifted(entry.plan_rows, current));
        if drifted {
            None
        } else {
            Some(entry.estimate.clone())
        }
    }

    /// Stores `estimate` under `shape_hash`, replacing any previous entry for that key.
    ///
    /// While the map holds at least half of `max_entries`, expired entries are purged first
    /// (an O(n) sweep). When the cache is still full, inserts of new keys are dropped, but
    /// an existing key may always be refreshed since that does not grow the map. Without
    /// this, a full cache could never replace an entry that `get` rejected for drift.
    /// A poisoned lock makes this a no-op.
    pub fn insert(&self, shape_hash: u64, estimate: ExplainEstimate) {
        let Ok(mut entries) = self.entries.lock() else {
            return;
        };
        if entries.len() >= self.max_entries / 2 {
            entries.retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
        }
        if entries.len() >= self.max_entries && !entries.contains_key(&shape_hash) {
            return;
        }
        entries.insert(
            shape_hash,
            CachedEstimate {
                plan_rows: estimate.plan_rows,
                estimate,
                cached_at: Instant::now(),
            },
        );
    }

    /// Number of stored entries, including expired ones not yet purged (0 if poisoned).
    pub fn len(&self) -> usize {
        self.entries.lock().map(|e| e.len()).unwrap_or(0)
    }

    /// Returns `true` when [`ExplainCache::len`] is 0.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Whether `current` differs from `baseline` by more than both drift thresholds.
    ///
    /// `abs_diff` is used so extreme row counts cannot overflow a signed subtraction
    /// (which would panic in debug builds while the cache lock is held).
    fn has_drifted(baseline: u64, current: u64) -> bool {
        if baseline == 0 {
            // Relative drift is undefined for an empty baseline.
            return false;
        }
        let base = baseline as f64;
        let ratio = (current as f64 - base).abs() / base;
        ratio > Self::DRIFT_RATIO && current.abs_diff(baseline) > Self::DRIFT_MIN_ROWS
    }
}
161
162pub fn parse_explain_json(json_str: &str) -> Option<ExplainEstimate> {
170 let total_cost = extract_json_number(json_str, "Total Cost")?;
171 let plan_rows = extract_json_number(json_str, "Plan Rows")? as u64;
172
173 Some(ExplainEstimate {
174 total_cost,
175 plan_rows,
176 })
177}
178
179fn extract_json_number(json: &str, key: &str) -> Option<f64> {
181 let pattern = format!("\"{}\":", key);
182 let start = json.find(&pattern)?;
183 let after_key = &json[start + pattern.len()..];
184
185 let trimmed = after_key.trim_start();
187
188 let end = trimmed.find(|c: char| !c.is_ascii_digit() && c != '.' && c != '-' && c != 'e' && c != 'E' && c != '+')?;
190 let num_str = &trimmed[..end];
191 num_str.parse::<f64>().ok()
192}
193
/// Outcome of comparing an EXPLAIN estimate against the configured limits.
#[derive(Debug)]
pub enum ExplainDecision {
    /// Returned by `check_estimate` when the estimate is within both limits.
    Allow,
    /// Returned by `check_estimate` when at least one limit is exceeded; carries the
    /// estimated figures and the limits they were compared against.
    Reject {
        total_cost: f64,
        plan_rows: u64,
        max_cost: f64,
        max_rows: u64,
    },
    /// No check was performed.
    /// NOTE(review): not produced in this file — presumably used when the mode is `Off` or
    /// the query is below the depth threshold; confirm against the caller.
    Skipped,
}

impl ExplainDecision {
    /// Returns `true` only for [`ExplainDecision::Reject`].
    pub fn is_rejected(&self) -> bool {
        matches!(self, ExplainDecision::Reject { .. })
    }

    /// Builds a user-facing rejection message, or `None` for `Allow` / `Skipped`.
    ///
    /// Only the limit(s) actually exceeded are named, so the message never claims a limit
    /// was exceeded when it was not. If a `Reject` value exceeds neither limit (e.g. one
    /// built by hand), both figures are reported in the combined "or" form.
    pub fn rejection_message(&self) -> Option<String> {
        match self {
            ExplainDecision::Reject {
                total_cost,
                plan_rows,
                max_cost,
                max_rows,
            } => {
                let mut reasons: Vec<String> = Vec::with_capacity(2);
                if *total_cost > *max_cost {
                    reasons.push(format!(
                        "estimated cost {:.0} exceeds limit {:.0}",
                        total_cost, max_cost
                    ));
                }
                if *plan_rows > *max_rows {
                    reasons.push(format!(
                        "estimated rows {} exceeds limit {}",
                        plan_rows, max_rows
                    ));
                }
                let reason = if reasons.is_empty() {
                    format!(
                        "estimated cost {:.0} exceeds limit {:.0}, or estimated rows {} exceeds limit {}",
                        total_cost, max_cost, plan_rows, max_rows
                    )
                } else {
                    reasons.join(" and ")
                };
                Some(format!(
                    "Query rejected: {}. \
                     Try narrowing your filters, reducing ?expand depth, or using pagination.",
                    reason
                ))
            }
            ExplainDecision::Allow | ExplainDecision::Skipped => None,
        }
    }

    /// Structured form of a rejection (figures, limits and remediation suggestions), or
    /// `None` for `Allow` / `Skipped`.
    pub fn rejection_detail(&self) -> Option<ExplainRejectionDetail> {
        match *self {
            ExplainDecision::Reject {
                total_cost,
                plan_rows,
                max_cost,
                max_rows,
            } => Some(ExplainRejectionDetail {
                estimated_cost: total_cost,
                cost_limit: max_cost,
                estimated_rows: plan_rows,
                row_limit: max_rows,
                suggestions: vec![
                    "Add WHERE clauses to narrow the result set".to_string(),
                    "Reduce ?expand depth (deep JOINs multiply cost)".to_string(),
                    "Use ?limit and ?offset for pagination".to_string(),
                    "Add indexes on frequently filtered columns".to_string(),
                ],
            }),
            ExplainDecision::Allow | ExplainDecision::Skipped => None,
        }
    }
}

/// Structured description of a rejected query: estimated figures, the limits they were
/// compared against, and remediation suggestions.
#[derive(Debug, Clone)]
pub struct ExplainRejectionDetail {
    /// Planner's estimated total cost.
    pub estimated_cost: f64,
    /// Configured cost limit.
    pub cost_limit: f64,
    /// Planner's estimated row count.
    pub estimated_rows: u64,
    /// Configured row limit.
    pub row_limit: u64,
    /// Human-readable hints for making the query cheaper.
    pub suggestions: Vec<String>,
}
265
266pub fn check_estimate(estimate: &ExplainEstimate, config: &ExplainConfig) -> ExplainDecision {
268 if estimate.total_cost > config.max_total_cost || estimate.plan_rows > config.max_plan_rows {
269 ExplainDecision::Reject {
270 total_cost: estimate.total_cost,
271 plan_rows: estimate.plan_rows,
272 max_cost: config.max_total_cost,
273 max_rows: config.max_plan_rows,
274 }
275 } else {
276 ExplainDecision::Allow
277 }
278}
279
// Unit tests covering JSON extraction, limit checks, and the estimate cache.
#[cfg(test)]
mod tests {
    use super::*;

    // Single-node plan: both values come from the only node in the JSON.
    #[test]
    fn test_parse_explain_json() {
        let json = r#"[{"Plan": {"Node Type": "Seq Scan", "Total Cost": 1234.56, "Plan Rows": 5000, "Plan Width": 100}}]"#;
        let est = parse_explain_json(json).unwrap();
        assert!((est.total_cost - 1234.56).abs() < 0.01);
        assert_eq!(est.plan_rows, 5000);
    }

    // NOTE(review): despite the name, this JSON has no nested "Plans" array; it only checks
    // that large cost/row values on a Hash Join node parse correctly.
    #[test]
    fn test_parse_explain_json_nested_join() {
        let json = r#"[{"Plan": {"Node Type": "Hash Join", "Total Cost": 250000.0, "Plan Rows": 2000000, "Plan Width": 200}}]"#;
        let est = parse_explain_json(json).unwrap();
        assert!((est.total_cost - 250000.0).abs() < 0.01);
        assert_eq!(est.plan_rows, 2_000_000);
    }

    // Inputs lacking the expected keys yield None rather than panicking.
    #[test]
    fn test_parse_explain_json_invalid() {
        assert!(parse_explain_json("not json").is_none());
        assert!(parse_explain_json("{}").is_none());
        assert!(parse_explain_json("[]").is_none());
    }

    // Well under both default limits (cost 100_000, rows 1_000_000).
    #[test]
    fn test_check_estimate_allow() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 100.0, plan_rows: 500 };
        let decision = check_estimate(&est, &config);
        assert!(!decision.is_rejected());
    }

    // Exceeding the cost limit alone rejects; the message reports the estimated cost.
    #[test]
    fn test_check_estimate_reject_cost() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 200_000.0, plan_rows: 500 };
        let decision = check_estimate(&est, &config);
        assert!(decision.is_rejected());
        assert!(decision.rejection_message().unwrap().contains("200000"));
    }

    // Exceeding the row limit alone rejects, even with a tiny cost.
    #[test]
    fn test_check_estimate_reject_rows() {
        let config = ExplainConfig::default();
        let est = ExplainEstimate { total_cost: 50.0, plan_rows: 5_000_000 };
        let decision = check_estimate(&est, &config);
        assert!(decision.is_rejected());
    }

    // Insert then read back; an unknown key is a miss.
    #[test]
    fn test_cache_basic() {
        let cache = ExplainCache::new(Duration::from_secs(60));
        assert!(cache.is_empty());

        cache.insert(42, ExplainEstimate { total_cost: 100.0, plan_rows: 50 });
        assert_eq!(cache.len(), 1);

        let cached = cache.get(42, None).unwrap();
        assert!((cached.total_cost - 100.0).abs() < 0.01);
        assert_eq!(cached.plan_rows, 50);

        assert!(cache.get(99, None).is_none());
    }

    // An entry older than the TTL is treated as a miss.
    #[test]
    fn test_cache_expiry() {
        let cache = ExplainCache::new(Duration::from_millis(1));
        cache.insert(1, ExplainEstimate { total_cost: 100.0, plan_rows: 50 });

        // Sleep well past the 1 ms TTL.
        std::thread::sleep(Duration::from_millis(5));
        assert!(cache.get(1, None).is_none());
    }

    // Drift invalidates only when the relative change exceeds 50% AND the absolute change
    // exceeds 10_000 rows.
    #[test]
    fn test_cache_drift_invalidation() {
        let cache = ExplainCache::new(Duration::from_secs(60));

        cache.insert(1, ExplainEstimate { total_cost: 50.0, plan_rows: 1000 });

        // No row count supplied: the drift check is skipped.
        assert!(cache.get(1, None).is_some());

        // Unchanged row count.
        assert!(cache.get(1, Some(1000)).is_some());

        // +60% relative but only 600 rows absolute: below the absolute floor.
        assert!(cache.get(1, Some(1600)).is_some(), "small table should not thrash");

        // -60% relative but only 600 rows absolute.
        assert!(cache.get(1, Some(400)).is_some(), "small shrinkage should not thrash");

        cache.insert(3, ExplainEstimate { total_cost: 500.0, plan_rows: 50_000 });

        // +70% and 35_000 rows: both thresholds exceeded.
        assert!(cache.get(3, Some(85_000)).is_none(), "large drift should invalidate");

        // +40%: the relative threshold is not exceeded.
        assert!(cache.get(3, Some(70_000)).is_some(), "moderate drift should not invalidate");

        // -60% and 30_000 rows: both thresholds exceeded.
        assert!(cache.get(3, Some(20_000)).is_none(), "large shrinkage should invalidate");

        cache.insert(2, ExplainEstimate { total_cost: 10.0, plan_rows: 0 });
        // A cached row count of 0 disables the drift check.
        assert!(cache.get(2, Some(999_999)).is_some());
    }

    // Pins the defaults produced by ExplainConfig::default.
    #[test]
    fn test_explain_mode_default() {
        let config = ExplainConfig::default();
        assert_eq!(config.mode, ExplainMode::Precheck);
        assert_eq!(config.depth_threshold, 3);
        assert!((config.max_total_cost - 100_000.0).abs() < 0.01);
        assert_eq!(config.max_plan_rows, 1_000_000);
    }
}