1#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct CorrelationPredicate {
35 pub outer_col: String,
37 pub inner_col: String,
39 pub op: CorrelationOp,
41}
42
/// Comparison operator used by a [`CorrelationPredicate`].
///
/// NOTE(review): the operand order for `Lt`/`Gt` (outer vs. inner) is not
/// fixed anywhere in this file — confirm against the code that builds
/// non-equality predicates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CorrelationOp {
    /// Equality comparison.
    Eq,
    /// Less-than comparison.
    Lt,
    /// Greater-than comparison.
    Gt,
}
53
54#[derive(Debug, Clone)]
56pub struct SubqueryAnalysis {
57 pub is_correlated: bool,
59 pub correlation_predicates: Vec<CorrelationPredicate>,
61 pub can_decorrelate: bool,
63 pub decorrelation_blocker: Option<DecorrelationBlocker>,
65 pub strategy: Option<DecorrelationStrategy>,
67}
68
/// Reason a correlated subquery cannot be turned into a join.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecorrelationBlocker {
    /// The correlation is not expressed through equality predicates.
    NonEqualityCorrelation,
    /// The subquery carries a `LIMIT`, which must be evaluated per outer row.
    CorrelationInLimit,
    /// The correlation crosses more than one nesting level.
    /// NOTE(review): inferred from the variant name — confirm with the producer.
    NestedCorrelation,
    /// The outer reference appears inside a `HAVING` clause.
    /// NOTE(review): inferred from the variant name — confirm with the producer.
    CorrelationInHaving,
    /// No join-based rewrite applies; a lateral join would be needed.
    RequiresLateralJoin,
}
83
/// Join-based rewrite chosen for a correlated subquery.
///
/// Every `join_condition` is a list of `(outer column, inner column)` pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecorrelationStrategy {
    /// Aggregate the subquery per correlation key, then inner-join the result.
    JoinWithGroupBy {
        /// Inner columns the derived table is grouped by.
        group_by_cols: Vec<String>,
        /// `(outer column, inner column)` pairs to join on.
        join_condition: Vec<(String, String)>,
    },

    /// Like [`DecorrelationStrategy::JoinWithGroupBy`] but joined with a left
    /// join.
    LeftJoinWithGroupBy {
        /// Inner columns the derived table is grouped by.
        group_by_cols: Vec<String>,
        /// `(outer column, inner column)` pairs to join on.
        join_condition: Vec<(String, String)>,
    },

    /// Semi join: keep outer rows that have at least one match.
    SemiJoin {
        /// `(outer column, inner column)` pairs to join on.
        join_condition: Vec<(String, String)>,
    },

    /// Anti join: keep outer rows that have no match.
    AntiJoin {
        /// `(outer column, inner column)` pairs to join on.
        join_condition: Vec<(String, String)>,
    },

    /// Deduplicate the inner side on the join columns before joining.
    DistinctJoin {
        /// `(outer column, inner column)` pairs to join on.
        join_condition: Vec<(String, String)>,
    },
}
118
/// Analyses correlated subqueries and plans their rewrite into joins.
#[derive(Debug)]
pub struct Decorrelator {
    /// Number of derived-table aliases handed out so far; keeps generated
    /// aliases unique per instance.
    alias_counter: usize,
}
124
125impl Decorrelator {
126 pub fn new() -> Self {
128 Self { alias_counter: 0 }
129 }
130
131 fn next_alias(&mut self) -> String {
133 self.alias_counter += 1;
134 format!("__derived_{}", self.alias_counter)
135 }
136
137 pub fn analyze(
139 &self,
140 outer_refs: &[String],
141 inner_cols: &[String],
142 correlation_predicates: &[(String, String)], subquery_type: SubqueryKind,
144 has_aggregation: bool,
145 has_limit: bool,
146 ) -> SubqueryAnalysis {
147 if outer_refs.is_empty() {
149 return SubqueryAnalysis {
150 is_correlated: false,
151 correlation_predicates: Vec::new(),
152 can_decorrelate: false,
153 decorrelation_blocker: None,
154 strategy: None,
155 };
156 }
157
158 let predicates: Vec<CorrelationPredicate> = correlation_predicates
160 .iter()
161 .map(|(outer, inner)| CorrelationPredicate {
162 outer_col: outer.clone(),
163 inner_col: inner.clone(),
164 op: CorrelationOp::Eq,
165 })
166 .collect();
167
168 if has_limit {
170 return SubqueryAnalysis {
171 is_correlated: true,
172 correlation_predicates: predicates,
173 can_decorrelate: false,
174 decorrelation_blocker: Some(DecorrelationBlocker::CorrelationInLimit),
175 strategy: None,
176 };
177 }
178
179 let strategy = match subquery_type {
181 SubqueryKind::Scalar if has_aggregation => {
182 let group_by_cols: Vec<String> =
184 predicates.iter().map(|p| p.inner_col.clone()).collect();
185 let join_condition: Vec<(String, String)> = predicates
186 .iter()
187 .map(|p| (p.outer_col.clone(), p.inner_col.clone()))
188 .collect();
189
190 Some(DecorrelationStrategy::JoinWithGroupBy {
191 group_by_cols,
192 join_condition,
193 })
194 }
195 SubqueryKind::Scalar => {
196 let group_by_cols: Vec<String> =
198 predicates.iter().map(|p| p.inner_col.clone()).collect();
199 let join_condition: Vec<(String, String)> = predicates
200 .iter()
201 .map(|p| (p.outer_col.clone(), p.inner_col.clone()))
202 .collect();
203
204 Some(DecorrelationStrategy::LeftJoinWithGroupBy {
205 group_by_cols,
206 join_condition,
207 })
208 }
209 SubqueryKind::Exists | SubqueryKind::In => {
210 let join_condition: Vec<(String, String)> = predicates
212 .iter()
213 .map(|p| (p.outer_col.clone(), p.inner_col.clone()))
214 .collect();
215
216 Some(DecorrelationStrategy::SemiJoin { join_condition })
217 }
218 SubqueryKind::NotExists | SubqueryKind::NotIn => {
219 let join_condition: Vec<(String, String)> = predicates
221 .iter()
222 .map(|p| (p.outer_col.clone(), p.inner_col.clone()))
223 .collect();
224
225 Some(DecorrelationStrategy::AntiJoin { join_condition })
226 }
227 SubqueryKind::Any | SubqueryKind::All => {
228 None
231 }
232 };
233
234 SubqueryAnalysis {
235 is_correlated: true,
236 correlation_predicates: predicates,
237 can_decorrelate: strategy.is_some(),
238 decorrelation_blocker: if strategy.is_none() {
239 Some(DecorrelationBlocker::RequiresLateralJoin)
240 } else {
241 None
242 },
243 strategy,
244 }
245 }
246
247 pub fn estimate_speedup(
250 &self,
251 outer_cardinality: usize,
252 inner_cardinality: usize,
253 strategy: &DecorrelationStrategy,
254 ) -> f64 {
255 let correlated_cost = (outer_cardinality * inner_cardinality) as f64;
257
258 let decorrelated_cost = match strategy {
260 DecorrelationStrategy::JoinWithGroupBy { group_by_cols, .. } => {
261 let group_by_cost = inner_cardinality as f64 * (group_by_cols.len() as f64).log2();
263 let join_cost = (outer_cardinality + inner_cardinality) as f64;
264 group_by_cost + join_cost
265 }
266 DecorrelationStrategy::LeftJoinWithGroupBy { .. } => {
267 (outer_cardinality + inner_cardinality) as f64 * 1.5
269 }
270 DecorrelationStrategy::SemiJoin { .. } | DecorrelationStrategy::AntiJoin { .. } => {
271 (outer_cardinality + inner_cardinality) as f64
273 }
274 DecorrelationStrategy::DistinctJoin { .. } => {
275 let distinct_cost = inner_cardinality as f64 * 1.2;
277 let join_cost = (outer_cardinality + inner_cardinality) as f64;
278 distinct_cost + join_cost
279 }
280 };
281
282 if decorrelated_cost < 1.0 {
284 return correlated_cost;
285 }
286
287 correlated_cost / decorrelated_cost
288 }
289
290 pub fn should_decorrelate(
292 &self,
293 outer_cardinality: usize,
294 inner_cardinality: usize,
295 strategy: &DecorrelationStrategy,
296 ) -> bool {
297 let speedup = self.estimate_speedup(outer_cardinality, inner_cardinality, strategy);
299 speedup > 1.5
300 }
301}
302
303impl Default for Decorrelator {
304 fn default() -> Self {
305 Self::new()
306 }
307}
308
/// Syntactic form in which a subquery appears in the outer query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubqueryKind {
    /// Scalar subquery used as a value.
    Scalar,
    /// `EXISTS (subquery)`.
    Exists,
    /// `NOT EXISTS (subquery)`.
    NotExists,
    /// `expr IN (subquery)`.
    In,
    /// `expr NOT IN (subquery)`.
    NotIn,
    /// `expr op ANY (subquery)`.
    Any,
    /// `expr op ALL (subquery)`.
    All,
}
327
328#[derive(Debug, Clone)]
334pub struct SubqueryRewrite {
335 pub derived_alias: String,
337 pub join_type: RewriteJoinType,
339 pub inner_select: Vec<String>,
341 pub group_by: Vec<String>,
343 pub join_on: Vec<(String, String)>,
345 pub result_col: Option<String>,
347}
348
/// Join flavour used to attach the derived table of a [`SubqueryRewrite`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RewriteJoinType {
    /// Inner join.
    Inner,
    /// Left outer join.
    Left,
    /// Semi join.
    Semi,
    /// Anti join.
    Anti,
}
357
358impl Decorrelator {
359 pub fn plan_rewrite(
361 &mut self,
362 analysis: &SubqueryAnalysis,
363 aggregation_col: Option<&str>,
364 ) -> Option<SubqueryRewrite> {
365 let strategy = analysis.strategy.as_ref()?;
366
367 let alias = self.next_alias();
368
369 match strategy {
370 DecorrelationStrategy::JoinWithGroupBy {
371 group_by_cols,
372 join_condition,
373 } => {
374 let mut inner_select = group_by_cols.clone();
375 let result_col = aggregation_col.map(|c| {
376 let col_name = format!("__agg_{}", c);
377 inner_select.push(col_name.clone());
378 col_name
379 });
380
381 Some(SubqueryRewrite {
382 derived_alias: alias.clone(),
383 join_type: RewriteJoinType::Inner,
384 inner_select,
385 group_by: group_by_cols.clone(),
386 join_on: join_condition
387 .iter()
388 .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
389 .collect(),
390 result_col,
391 })
392 }
393 DecorrelationStrategy::LeftJoinWithGroupBy {
394 group_by_cols,
395 join_condition,
396 } => {
397 let mut inner_select = group_by_cols.clone();
398 let result_col = aggregation_col.map(|c| {
399 let col_name = format!("__agg_{}", c);
400 inner_select.push(col_name.clone());
401 col_name
402 });
403
404 Some(SubqueryRewrite {
405 derived_alias: alias.clone(),
406 join_type: RewriteJoinType::Left,
407 inner_select,
408 group_by: group_by_cols.clone(),
409 join_on: join_condition
410 .iter()
411 .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
412 .collect(),
413 result_col,
414 })
415 }
416 DecorrelationStrategy::SemiJoin { join_condition } => Some(SubqueryRewrite {
417 derived_alias: alias.clone(),
418 join_type: RewriteJoinType::Semi,
419 inner_select: join_condition.iter().map(|(_, i)| i.clone()).collect(),
420 group_by: Vec::new(),
421 join_on: join_condition
422 .iter()
423 .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
424 .collect(),
425 result_col: None,
426 }),
427 DecorrelationStrategy::AntiJoin { join_condition } => Some(SubqueryRewrite {
428 derived_alias: alias.clone(),
429 join_type: RewriteJoinType::Anti,
430 inner_select: join_condition.iter().map(|(_, i)| i.clone()).collect(),
431 group_by: Vec::new(),
432 join_on: join_condition
433 .iter()
434 .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
435 .collect(),
436 result_col: None,
437 }),
438 DecorrelationStrategy::DistinctJoin { join_condition } => {
439 Some(SubqueryRewrite {
440 derived_alias: alias.clone(),
441 join_type: RewriteJoinType::Semi,
442 inner_select: join_condition.iter().map(|(_, i)| i.clone()).collect(),
443 group_by: join_condition.iter().map(|(_, i)| i.clone()).collect(), join_on: join_condition
445 .iter()
446 .map(|(o, i)| (o.clone(), format!("{}.{}", alias, i)))
447 .collect(),
448 result_col: None,
449 })
450 }
451 }
452 }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building owned column names in fixtures.
    fn col(name: &str) -> String {
        name.to_string()
    }

    /// Shorthand for a one-element semi-join strategy used by the cost tests.
    fn semi_join() -> DecorrelationStrategy {
        DecorrelationStrategy::SemiJoin {
            join_condition: vec![(col("a"), col("b"))],
        }
    }

    #[test]
    fn test_non_correlated() {
        let decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &[],
            &[col("id"), col("value")],
            &[],
            SubqueryKind::Scalar,
            true,
            false,
        );

        assert!(!analysis.is_correlated);
        assert!(!analysis.can_decorrelate);
        assert!(analysis.correlation_predicates.is_empty());
        assert_eq!(analysis.decorrelation_blocker, None);
        assert_eq!(analysis.strategy, None);
    }

    #[test]
    fn test_scalar_aggregation_decorrelation() {
        let decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &[col("o.customer_id")],
            &[col("customer_id"), col("total")],
            &[(col("o.customer_id"), col("customer_id"))],
            SubqueryKind::Scalar,
            true,
            false,
        );

        assert!(analysis.is_correlated);
        assert!(analysis.can_decorrelate);
        assert_eq!(analysis.decorrelation_blocker, None);
        assert_eq!(
            analysis.correlation_predicates,
            vec![CorrelationPredicate {
                outer_col: col("o.customer_id"),
                inner_col: col("customer_id"),
                op: CorrelationOp::Eq,
            }]
        );
        assert_eq!(
            analysis.strategy,
            Some(DecorrelationStrategy::JoinWithGroupBy {
                group_by_cols: vec![col("customer_id")],
                join_condition: vec![(col("o.customer_id"), col("customer_id"))],
            })
        );
    }

    #[test]
    fn test_scalar_without_aggregation_uses_left_join() {
        let decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &[col("o.id")],
            &[col("order_id")],
            &[(col("o.id"), col("order_id"))],
            SubqueryKind::Scalar,
            false,
            false,
        );

        assert!(analysis.can_decorrelate);
        assert_eq!(
            analysis.strategy,
            Some(DecorrelationStrategy::LeftJoinWithGroupBy {
                group_by_cols: vec![col("order_id")],
                join_condition: vec![(col("o.id"), col("order_id"))],
            })
        );
    }

    #[test]
    fn test_exists_decorrelation() {
        let decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &[col("o.id")],
            &[col("order_id")],
            &[(col("o.id"), col("order_id"))],
            SubqueryKind::Exists,
            false,
            false,
        );

        assert!(analysis.is_correlated);
        assert!(analysis.can_decorrelate);
        assert_eq!(
            analysis.strategy,
            Some(DecorrelationStrategy::SemiJoin {
                join_condition: vec![(col("o.id"), col("order_id"))],
            })
        );
    }

    #[test]
    fn test_not_exists_uses_anti_join() {
        let decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &[col("o.id")],
            &[col("order_id")],
            &[(col("o.id"), col("order_id"))],
            SubqueryKind::NotExists,
            false,
            false,
        );

        assert!(analysis.can_decorrelate);
        assert_eq!(
            analysis.strategy,
            Some(DecorrelationStrategy::AntiJoin {
                join_condition: vec![(col("o.id"), col("order_id"))],
            })
        );
    }

    #[test]
    fn test_any_requires_lateral_join() {
        let decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &[col("o.id")],
            &[col("order_id")],
            &[(col("o.id"), col("order_id"))],
            SubqueryKind::Any,
            false,
            false,
        );

        assert!(analysis.is_correlated);
        assert!(!analysis.can_decorrelate);
        assert_eq!(analysis.strategy, None);
        assert_eq!(
            analysis.decorrelation_blocker,
            Some(DecorrelationBlocker::RequiresLateralJoin)
        );
    }

    #[test]
    fn test_limit_blocks_decorrelation() {
        let decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &[col("o.id")],
            &[col("order_id")],
            &[(col("o.id"), col("order_id"))],
            SubqueryKind::Scalar,
            false,
            true,
        );

        assert!(analysis.is_correlated);
        assert!(!analysis.can_decorrelate);
        assert_eq!(analysis.strategy, None);
        assert_eq!(
            analysis.decorrelation_blocker,
            Some(DecorrelationBlocker::CorrelationInLimit)
        );
    }

    #[test]
    fn test_speedup_estimation() {
        let decorrelator = Decorrelator::new();

        // Correlated cost 1000 * 1000 against a join cost of 1000 + 1000.
        let speedup = decorrelator.estimate_speedup(1000, 1000, &semi_join());

        assert!((speedup - 500.0).abs() < 1e-9);
    }

    #[test]
    fn test_should_decorrelate_threshold() {
        let decorrelator = Decorrelator::new();

        assert!(decorrelator.should_decorrelate(1000, 1000, &semi_join()));
        // 1 * 1 against 1 + 1 gives a speedup of 0.5, below the threshold.
        assert!(!decorrelator.should_decorrelate(1, 1, &semi_join()));
    }

    #[test]
    fn test_rewrite_plan() {
        let mut decorrelator = Decorrelator::new();

        let analysis = decorrelator.analyze(
            &[col("o.customer_id")],
            &[col("customer_id"), col("total")],
            &[(col("o.customer_id"), col("customer_id"))],
            SubqueryKind::Scalar,
            true,
            false,
        );

        let rewrite = decorrelator
            .plan_rewrite(&analysis, Some("avg_total"))
            .expect("a decorrelatable analysis must yield a rewrite");

        assert_eq!(rewrite.derived_alias, "__derived_1");
        assert_eq!(rewrite.join_type, RewriteJoinType::Inner);
        assert_eq!(
            rewrite.inner_select,
            vec![col("customer_id"), col("__agg_avg_total")]
        );
        assert_eq!(rewrite.group_by, vec![col("customer_id")]);
        assert_eq!(
            rewrite.join_on,
            vec![(col("o.customer_id"), col("__derived_1.customer_id"))]
        );
        assert_eq!(rewrite.result_col.as_deref(), Some("__agg_avg_total"));
    }

    #[test]
    fn test_rewrite_aliases_are_unique() {
        let mut decorrelator = Decorrelator::new();

        let analysis = decorrelator.analyze(
            &[col("o.id")],
            &[col("order_id")],
            &[(col("o.id"), col("order_id"))],
            SubqueryKind::Exists,
            false,
            false,
        );

        let first = decorrelator
            .plan_rewrite(&analysis, None)
            .expect("semi-join analysis must yield a rewrite");
        let second = decorrelator
            .plan_rewrite(&analysis, None)
            .expect("semi-join analysis must yield a rewrite");

        assert_eq!(first.derived_alias, "__derived_1");
        assert_eq!(second.derived_alias, "__derived_2");
        assert_eq!(first.join_type, RewriteJoinType::Semi);
        assert!(first.group_by.is_empty());
        assert_eq!(first.result_col, None);
        assert_eq!(
            second.join_on,
            vec![(col("o.id"), col("__derived_2.order_id"))]
        );
    }

    #[test]
    fn test_rewrite_without_strategy() {
        let mut decorrelator = Decorrelator::new();

        let analysis = decorrelator.analyze(
            &[],
            &[col("id")],
            &[],
            SubqueryKind::Scalar,
            false,
            false,
        );

        assert!(decorrelator.plan_rewrite(&analysis, None).is_none());
    }
}