/// One predicate that ties a column of the outer query to a column of the
/// subquery.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorrelationPredicate {
    /// Name of the column on the outer-query side of the predicate.
    pub outer_col: String,
    /// Name of the column on the subquery side of the predicate.
    pub inner_col: String,
    /// Comparison operator relating the two columns.
    pub op: CorrelationOp,
}
/// Comparison operator used by a [`CorrelationPredicate`].
///
/// NOTE(review): which side of the comparison the outer column sits on for
/// `Lt`/`Gt` is not established by this file — confirm against the producer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CorrelationOp {
    /// Equality comparison.
    Eq,
    /// Less-than comparison.
    Lt,
    /// Greater-than comparison.
    Gt,
}
/// Outcome of analysing a subquery for decorrelation.
///
/// Derives `PartialEq`/`Eq` (every field type already supports them) so that
/// whole analyses can be compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SubqueryAnalysis {
    /// Whether the subquery references the outer query at all.
    pub is_correlated: bool,
    /// Predicates that link outer columns to subquery columns.
    pub correlation_predicates: Vec<CorrelationPredicate>,
    /// Whether a join-based rewrite strategy was found.
    pub can_decorrelate: bool,
    /// Why the subquery cannot be rewritten, when a reason was recorded.
    pub decorrelation_blocker: Option<DecorrelationBlocker>,
    /// The chosen rewrite strategy, if any.
    pub strategy: Option<DecorrelationStrategy>,
}
/// Reason a correlated subquery cannot be rewritten into a join.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecorrelationBlocker {
    /// The correlation is not expressible as equality between columns
    /// (meaning inferred from the name — confirm with the producer).
    NonEqualityCorrelation,
    /// The subquery carries a `LIMIT`; `Decorrelator::analyze` reports this
    /// whenever its `has_limit` flag is set on a correlated subquery.
    CorrelationInLimit,
    /// The correlation spans more than one nesting level (inferred from the
    /// name; not produced by the code visible in this file).
    NestedCorrelation,
    /// The correlation appears in a `HAVING` clause (inferred from the name;
    /// not produced by the code visible in this file).
    CorrelationInHaving,
    /// No join-based strategy applies; `Decorrelator::analyze` reports this
    /// when it selects no strategy for the subquery kind.
    RequiresLateralJoin,
}
/// Join-based plan shape chosen to replace a correlated subquery.
///
/// Every `join_condition` is a list of `(outer column, inner column)` pairs.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecorrelationStrategy {
    /// Group the subquery by its correlation columns and join the result;
    /// `Decorrelator::plan_rewrite` plans this as an inner join.
    JoinWithGroupBy {
        /// Subquery columns to group by.
        group_by_cols: Vec<String>,
        /// `(outer, inner)` column pairs to join on.
        join_condition: Vec<(String, String)>,
    },
    /// Same shape as `JoinWithGroupBy`, but planned as a left join.
    LeftJoinWithGroupBy {
        /// Subquery columns to group by.
        group_by_cols: Vec<String>,
        /// `(outer, inner)` column pairs to join on.
        join_condition: Vec<(String, String)>,
    },
    /// Semi join on the correlation columns.
    SemiJoin {
        /// `(outer, inner)` column pairs to join on.
        join_condition: Vec<(String, String)>,
    },
    /// Anti join on the correlation columns.
    AntiJoin {
        /// `(outer, inner)` column pairs to join on.
        join_condition: Vec<(String, String)>,
    },
    /// Join against the subquery grouped on its join columns;
    /// `Decorrelator::plan_rewrite` plans this with a semi join type.
    DistinctJoin {
        /// `(outer, inner)` column pairs to join on.
        join_condition: Vec<(String, String)>,
    },
}
/// Analyses correlated subqueries and plans join-based rewrites for them.
///
/// Derives `Debug` so the planner can be logged or embedded in other
/// `Debug` types like the rest of the public types in this file.
#[derive(Debug)]
pub struct Decorrelator {
    /// Number of derived-table aliases handed out so far; the next alias is
    /// built from this value plus one.
    alias_counter: usize,
}
impl Decorrelator {
pub fn new() -> Self {
Self { alias_counter: 0 }
}
fn next_alias(&mut self) -> String {
self.alias_counter += 1;
format!("__derived_{}", self.alias_counter)
}
pub fn analyze(
&self,
outer_refs: &[String],
inner_cols: &[String],
correlation_predicates: &[(String, String)], subquery_type: SubqueryKind,
has_aggregation: bool,
has_limit: bool,
) -> SubqueryAnalysis {
if outer_refs.is_empty() {
return SubqueryAnalysis {
is_correlated: false,
correlation_predicates: Vec::new(),
can_decorrelate: false,
decorrelation_blocker: None,
strategy: None,
};
}
let predicates: Vec<CorrelationPredicate> = correlation_predicates
.iter()
.map(|(outer, inner)| CorrelationPredicate {
outer_col: outer.clone(),
inner_col: inner.clone(),
op: CorrelationOp::Eq,
})
.collect();
if has_limit {
return SubqueryAnalysis {
is_correlated: true,
correlation_predicates: predicates,
can_decorrelate: false,
decorrelation_blocker: Some(DecorrelationBlocker::CorrelationInLimit),
strategy: None,
};
}
let strategy = match subquery_type {
SubqueryKind::Scalar if has_aggregation => {
let group_by_cols: Vec<String> =
predicates.iter().map(|p| p.inner_col.clone()).collect();
let join_condition: Vec<(String, String)> = predicates
.iter()
.map(|p| (p.outer_col.clone(), p.inner_col.clone()))
.collect();
Some(DecorrelationStrategy::JoinWithGroupBy {
group_by_cols,
join_condition,
})
}
SubqueryKind::Scalar => {
let group_by_cols: Vec<String> =
predicates.iter().map(|p| p.inner_col.clone()).collect();
let join_condition: Vec<(String, String)> = predicates
.iter()
.map(|p| (p.outer_col.clone(), p.inner_col.clone()))
.collect();
Some(DecorrelationStrategy::LeftJoinWithGroupBy {
group_by_cols,
join_condition,
})
}
SubqueryKind::Exists | SubqueryKind::In => {
let join_condition: Vec<(String, String)> = predicates
.iter()
.map(|p| (p.outer_col.clone(), p.inner_col.clone()))
.collect();
Some(DecorrelationStrategy::SemiJoin { join_condition })
}
SubqueryKind::NotExists | SubqueryKind::NotIn => {
let join_condition: Vec<(String, String)> = predicates
.iter()
.map(|p| (p.outer_col.clone(), p.inner_col.clone()))
.collect();
Some(DecorrelationStrategy::AntiJoin { join_condition })
}
SubqueryKind::Any | SubqueryKind::All => {
None
}
};
SubqueryAnalysis {
is_correlated: true,
correlation_predicates: predicates,
can_decorrelate: strategy.is_some(),
decorrelation_blocker: if strategy.is_none() {
Some(DecorrelationBlocker::RequiresLateralJoin)
} else {
None
},
strategy,
}
}
pub fn estimate_speedup(
&self,
outer_cardinality: usize,
inner_cardinality: usize,
strategy: &DecorrelationStrategy,
) -> f64 {
let correlated_cost = (outer_cardinality * inner_cardinality) as f64;
let decorrelated_cost = match strategy {
DecorrelationStrategy::JoinWithGroupBy { group_by_cols, .. } => {
let group_by_cost = inner_cardinality as f64 * (group_by_cols.len() as f64).log2();
let join_cost = (outer_cardinality + inner_cardinality) as f64;
group_by_cost + join_cost
}
DecorrelationStrategy::LeftJoinWithGroupBy { .. } => {
(outer_cardinality + inner_cardinality) as f64 * 1.5
}
DecorrelationStrategy::SemiJoin { .. } | DecorrelationStrategy::AntiJoin { .. } => {
(outer_cardinality + inner_cardinality) as f64
}
DecorrelationStrategy::DistinctJoin { .. } => {
let distinct_cost = inner_cardinality as f64 * 1.2;
let join_cost = (outer_cardinality + inner_cardinality) as f64;
distinct_cost + join_cost
}
};
if decorrelated_cost < 1.0 {
return correlated_cost;
}
correlated_cost / decorrelated_cost
}
pub fn should_decorrelate(
&self,
outer_cardinality: usize,
inner_cardinality: usize,
strategy: &DecorrelationStrategy,
) -> bool {
let speedup = self.estimate_speedup(outer_cardinality, inner_cardinality, strategy);
speedup > 1.5
}
}
impl Default for Decorrelator {
fn default() -> Self {
Self::new()
}
}
/// Syntactic form of a subquery, as passed to `Decorrelator::analyze`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubqueryKind {
    /// Subquery used as a scalar value.
    Scalar,
    /// `EXISTS (...)`.
    Exists,
    /// `NOT EXISTS (...)`.
    NotExists,
    /// `expr IN (...)`.
    In,
    /// `expr NOT IN (...)`.
    NotIn,
    /// Quantified `ANY` comparison; `analyze` selects no join strategy for it.
    Any,
    /// Quantified `ALL` comparison; `analyze` selects no join strategy for it.
    All,
}
#[derive(Debug, Clone)]
pub struct SubqueryRewrite {
pub derived_alias: String,
pub join_type: RewriteJoinType,
pub inner_select: Vec<String>,
pub group_by: Vec<String>,
pub join_on: Vec<(String, String)>,
pub result_col: Option<String>,
}
/// Join flavour used by a [`SubqueryRewrite`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RewriteJoinType {
    /// Inner join.
    Inner,
    /// Left outer join.
    Left,
    /// Semi join.
    Semi,
    /// Anti join.
    Anti,
}
impl Decorrelator {
    /// Turns the strategy stored in `analysis` into a concrete rewrite.
    ///
    /// Returns `None`, without consuming an alias, when the analysis carries
    /// no strategy. Otherwise a fresh derived-table alias is allocated and:
    ///
    /// - grouped strategies project their group-by columns plus, when
    ///   `aggregation_col` is given, a result column named `__agg_<col>`;
    /// - semi, anti and distinct strategies project the inner join columns,
    ///   and the distinct strategy additionally groups by them;
    /// - every join pair keeps the outer column as-is and qualifies the inner
    ///   column with the derived alias (`<alias>.<inner>`).
    pub fn plan_rewrite(
        &mut self,
        analysis: &SubqueryAnalysis,
        aggregation_col: Option<&str>,
    ) -> Option<SubqueryRewrite> {
        let strategy = analysis.strategy.as_ref()?;
        let derived_alias = self.next_alias();

        // Projection and result column shared by both grouped strategies.
        let grouped_projection = |keys: &[String]| -> (Vec<String>, Option<String>) {
            let result_col = aggregation_col.map(|c| format!("__agg_{}", c));
            let mut projection = keys.to_vec();
            projection.extend(result_col.iter().cloned());
            (projection, result_col)
        };
        // Inner-side column of every join pair.
        let inner_side = |pairs: &[(String, String)]| -> Vec<String> {
            pairs.iter().map(|(_, inner)| inner.clone()).collect()
        };

        let (join_type, pairs, inner_select, group_by, result_col) = match strategy {
            DecorrelationStrategy::JoinWithGroupBy {
                group_by_cols,
                join_condition,
            } => {
                let (projection, result_col) = grouped_projection(group_by_cols);
                (
                    RewriteJoinType::Inner,
                    join_condition,
                    projection,
                    group_by_cols.clone(),
                    result_col,
                )
            }
            DecorrelationStrategy::LeftJoinWithGroupBy {
                group_by_cols,
                join_condition,
            } => {
                let (projection, result_col) = grouped_projection(group_by_cols);
                (
                    RewriteJoinType::Left,
                    join_condition,
                    projection,
                    group_by_cols.clone(),
                    result_col,
                )
            }
            DecorrelationStrategy::SemiJoin { join_condition } => (
                RewriteJoinType::Semi,
                join_condition,
                inner_side(join_condition),
                Vec::new(),
                None,
            ),
            DecorrelationStrategy::AntiJoin { join_condition } => (
                RewriteJoinType::Anti,
                join_condition,
                inner_side(join_condition),
                Vec::new(),
                None,
            ),
            DecorrelationStrategy::DistinctJoin { join_condition } => (
                RewriteJoinType::Semi,
                join_condition,
                inner_side(join_condition),
                inner_side(join_condition),
                None,
            ),
        };

        let join_on: Vec<(String, String)> = pairs
            .iter()
            .map(|(outer, inner)| (outer.clone(), format!("{}.{}", derived_alias, inner)))
            .collect();

        Some(SubqueryRewrite {
            derived_alias,
            join_type,
            inner_select,
            group_by,
            join_on,
            result_col,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Owned column name for test fixtures.
    fn col(name: &str) -> String {
        name.to_owned()
    }

    /// Single `(outer, inner)` correlation pair.
    fn pair(outer: &str, inner: &str) -> (String, String) {
        (col(outer), col(inner))
    }

    /// A subquery with no outer references is reported as uncorrelated.
    #[test]
    fn test_non_correlated() {
        let analysis = Decorrelator::new().analyze(
            &[],
            &[col("id"), col("value")],
            &[],
            SubqueryKind::Scalar,
            true,
            false,
        );

        assert!(!analysis.is_correlated);
        assert!(!analysis.can_decorrelate);
    }

    /// An aggregated scalar subquery is rewritten as a grouped join.
    #[test]
    fn test_scalar_aggregation_decorrelation() {
        let has_aggregation = true;
        let has_limit = false;
        let analysis = Decorrelator::new().analyze(
            &[col("o.customer_id")],
            &[col("customer_id"), col("total")],
            &[pair("o.customer_id", "customer_id")],
            SubqueryKind::Scalar,
            has_aggregation,
            has_limit,
        );

        assert!(analysis.is_correlated);
        assert!(analysis.can_decorrelate);
        assert!(matches!(
            analysis.strategy,
            Some(DecorrelationStrategy::JoinWithGroupBy { .. })
        ));
    }

    /// A correlated `EXISTS` is rewritten as a semi join.
    #[test]
    fn test_exists_decorrelation() {
        let analysis = Decorrelator::new().analyze(
            &[col("o.id")],
            &[col("order_id")],
            &[pair("o.id", "order_id")],
            SubqueryKind::Exists,
            false,
            false,
        );

        assert!(analysis.is_correlated);
        assert!(analysis.can_decorrelate);
        assert!(matches!(
            analysis.strategy,
            Some(DecorrelationStrategy::SemiJoin { .. })
        ));
    }

    /// A `LIMIT` inside a correlated subquery blocks the rewrite.
    #[test]
    fn test_limit_blocks_decorrelation() {
        let has_aggregation = false;
        let has_limit = true;
        let analysis = Decorrelator::new().analyze(
            &[col("o.id")],
            &[col("order_id")],
            &[pair("o.id", "order_id")],
            SubqueryKind::Scalar,
            has_aggregation,
            has_limit,
        );

        assert!(analysis.is_correlated);
        assert!(!analysis.can_decorrelate);
        assert_eq!(
            analysis.decorrelation_blocker,
            Some(DecorrelationBlocker::CorrelationInLimit)
        );
    }

    /// A semi join over 1000 x 1000 rows is estimated as a large win.
    #[test]
    fn test_speedup_estimation() {
        let strategy = DecorrelationStrategy::SemiJoin {
            join_condition: vec![pair("a", "b")],
        };

        let speedup = Decorrelator::new().estimate_speedup(1000, 1000, &strategy);

        assert!(speedup > 100.0);
    }

    /// The grouped-join strategy plans an inner join with a result column.
    #[test]
    fn test_rewrite_plan() {
        let mut decorrelator = Decorrelator::new();
        let analysis = decorrelator.analyze(
            &[col("o.customer_id")],
            &[col("customer_id"), col("total")],
            &[pair("o.customer_id", "customer_id")],
            SubqueryKind::Scalar,
            true,
            false,
        );

        let rewrite = decorrelator
            .plan_rewrite(&analysis, Some("avg_total"))
            .expect("a decorrelatable analysis must yield a rewrite");

        assert_eq!(rewrite.join_type, RewriteJoinType::Inner);
        assert!(rewrite.group_by.iter().any(|c| c == "customer_id"));
        assert!(rewrite.result_col.is_some());
    }
}