use super::{BuildSide, JoinAlgorithm};
/// Deviation factor that [`JoinAqeContext::should_switch`] compares against:
/// a switch is considered only when the estimation error is strictly greater
/// than this value.
pub const AQE_SWITCH_THRESHOLD: f64 = 10.0;
/// Row-count bound used by [`decide_join_algorithm`]: when both inputs have at
/// most this many actual rows, a nested-loop join is chosen instead of a hash join.
pub const MIN_ROWS_FOR_HASH_JOIN: u64 = 100;
/// Per-side row ceiling associated with nested-loop joins (going by the name).
/// NOTE(review): confirm how callers outside this file use this value.
pub const MAX_ROWS_FOR_NESTED_LOOP: u64 = 1000;
/// Outcome of re-evaluating a planned join against the actual input row counts,
/// as returned by [`decide_join_algorithm`].
#[derive(Debug, Clone, PartialEq)]
pub enum AqeJoinDecision {
/// Keep executing the algorithm that was originally planned.
KeepPlanned,
/// Switch to a hash join that uses `build_side` as its build input.
SwitchToHashJoin { build_side: BuildSide },
/// Switch to a nested-loop join.
SwitchToNestedLoop,
}
/// Inputs to the join re-planning decision: estimated versus actual row counts
/// for both join inputs, plus the algorithm that was originally planned.
#[derive(Debug, Clone)]
pub struct JoinAqeContext {
/// Estimated row count of the left input.
pub estimated_left: u64,
/// Estimated row count of the right input.
pub estimated_right: u64,
/// Actual row count of the left input.
pub actual_left: u64,
/// Actual row count of the right input.
pub actual_right: u64,
/// Whether the join has equality keys; without them
/// [`decide_join_algorithm`] never selects a hash join.
pub has_equality_keys: bool,
/// The join algorithm chosen before actual row counts were known.
pub planned_algorithm: JoinAlgorithm,
}
impl JoinAqeContext {
pub fn new(
estimated_left: u64,
estimated_right: u64,
actual_left: u64,
actual_right: u64,
has_equality_keys: bool,
planned_algorithm: JoinAlgorithm,
) -> Self {
Self {
estimated_left,
estimated_right,
actual_left,
actual_right,
has_equality_keys,
planned_algorithm,
}
}
pub fn estimation_error(&self) -> f64 {
let left_error = if self.estimated_left > 0 {
self.actual_left as f64 / self.estimated_left as f64
} else if self.actual_left > 0 {
f64::MAX
} else {
1.0
};
let right_error = if self.estimated_right > 0 {
self.actual_right as f64 / self.estimated_right as f64
} else if self.actual_right > 0 {
f64::MAX
} else {
1.0
};
let left_deviation = if left_error > 1.0 {
left_error
} else {
1.0 / left_error
};
let right_deviation = if right_error > 1.0 {
right_error
} else {
1.0 / right_error
};
left_deviation.max(right_deviation)
}
pub fn should_switch(&self) -> bool {
self.estimation_error() > AQE_SWITCH_THRESHOLD
}
}
/// Decides whether the planned join algorithm should be replaced now that the
/// actual input row counts in `ctx` are known.
///
/// The decision proceeds as follows:
/// 1. If the estimation error does not exceed the switch threshold
///    ([`JoinAqeContext::should_switch`] is `false`), the plan is kept.
/// 2. If both inputs have at most [`MIN_ROWS_FOR_HASH_JOIN`] actual rows, or
///    the join has no equality keys, a nested-loop join is selected
///    (`KeepPlanned` if that is already the plan, `SwitchToNestedLoop` otherwise).
/// 3. Otherwise a hash join is selected with the input that has fewer actual
///    rows as the build side (the left input on a tie). `KeepPlanned` is
///    returned if the plan is already a hash join with that build side,
///    `SwitchToHashJoin` otherwise.
///
/// The function performs no arithmetic on the row counts, so it cannot
/// overflow for any `u64` inputs.
pub fn decide_join_algorithm(ctx: &JoinAqeContext) -> AqeJoinDecision {
    if !ctx.should_switch() {
        return AqeJoinDecision::KeepPlanned;
    }

    let both_inputs_small = ctx.actual_left <= MIN_ROWS_FOR_HASH_JOIN
        && ctx.actual_right <= MIN_ROWS_FOR_HASH_JOIN;

    // Tiny inputs do not justify a hash join, and a join without equality keys
    // cannot use one; both cases resolve to a nested loop.
    if both_inputs_small || !ctx.has_equality_keys {
        return match &ctx.planned_algorithm {
            JoinAlgorithm::NestedLoop { .. } => AqeJoinDecision::KeepPlanned,
            _ => AqeJoinDecision::SwitchToNestedLoop,
        };
    }

    // From here on at least one input exceeds MIN_ROWS_FOR_HASH_JOIN and
    // equality keys are available, so a hash join is always the choice. No
    // further size thresholds are needed: summing or multiplying the row
    // counts would not change the outcome and would only risk overflow.
    let build_side = if ctx.actual_left <= ctx.actual_right {
        BuildSide::Left
    } else {
        BuildSide::Right
    };

    match &ctx.planned_algorithm {
        // Already a hash join building on the preferred side: nothing to change.
        JoinAlgorithm::HashJoin {
            build_side: planned_side,
            ..
        } if *planned_side == build_side => AqeJoinDecision::KeepPlanned,
        _ => AqeJoinDecision::SwitchToHashJoin { build_side },
    }
}
/// Record of a join execution: the planned and used algorithms together with
/// the actual row counts of both inputs.
/// NOTE(review): nothing in this file constructs this type; field meanings
/// below follow the field names — confirm against the code that fills them in.
#[derive(Debug)]
pub struct AqeJoinResult {
/// Presumably `true` when `used_algorithm` differs from `original_algorithm`.
pub switched: bool,
/// The algorithm that was originally planned.
pub original_algorithm: JoinAlgorithm,
/// The algorithm that was used for execution.
pub used_algorithm: JoinAlgorithm,
/// Actual row count of the left input.
pub actual_left_rows: u64,
/// Actual row count of the right input.
pub actual_right_rows: u64,
}
impl AqeJoinResult {
pub fn has_significant_error(&self, estimated_left: u64, estimated_right: u64) -> bool {
let ctx = JoinAqeContext::new(
estimated_left,
estimated_right,
self.actual_left_rows,
self.actual_right_rows,
true, self.original_algorithm.clone(),
);
ctx.should_switch()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Nested-loop plan sized from the 100 x 100 estimate used by most tests.
    fn planned_nested_loop() -> JoinAlgorithm {
        JoinAlgorithm::NestedLoop {
            outer_rows: 100,
            inner_rows: 100,
        }
    }

    /// A 10x under-estimate on one side yields an estimation error of 10.
    #[test]
    fn test_estimation_error_calculation() {
        let ctx = JoinAqeContext::new(100, 100, 1000, 100, true, planned_nested_loop());
        let error = ctx.estimation_error();
        assert!((error - 10.0).abs() < 0.001);
    }

    /// An 11x deviation is above the threshold and triggers a switch.
    #[test]
    fn test_should_switch_above_threshold() {
        let ctx = JoinAqeContext::new(100, 100, 1100, 100, true, planned_nested_loop());
        assert!(ctx.should_switch());
    }

    /// A 5x deviation stays below the threshold.
    #[test]
    fn test_should_not_switch_below_threshold() {
        let ctx = JoinAqeContext::new(100, 100, 500, 100, true, planned_nested_loop());
        assert!(!ctx.should_switch());
    }

    /// Both inputs turning out tiny moves a planned hash join to a nested loop.
    #[test]
    fn test_small_inputs_prefer_nested_loop() {
        let planned = JoinAlgorithm::HashJoin {
            build_side: BuildSide::Left,
            build_rows: 1000,
            probe_rows: 1000,
        };
        let ctx = JoinAqeContext::new(1000, 1000, 50, 50, true, planned);
        assert!(ctx.should_switch());
        assert_eq!(
            decide_join_algorithm(&ctx),
            AqeJoinDecision::SwitchToNestedLoop
        );
    }

    /// Inputs far larger than estimated move a planned nested loop to a hash join.
    #[test]
    fn test_large_inputs_prefer_hash_join() {
        let ctx = JoinAqeContext::new(100, 100, 10000, 10000, true, planned_nested_loop());
        let decision = decide_join_algorithm(&ctx);
        assert!(
            matches!(decision, AqeJoinDecision::SwitchToHashJoin { .. }),
            "Expected switch to hash join"
        );
    }

    /// Without equality keys the decision never selects a hash join.
    #[test]
    fn test_no_equality_keys_must_use_nested_loop() {
        let ctx = JoinAqeContext::new(100, 100, 10000, 10000, false, planned_nested_loop());
        let decision = decide_join_algorithm(&ctx);
        let stays_nested_loop = matches!(
            decision,
            AqeJoinDecision::KeepPlanned | AqeJoinDecision::SwitchToNestedLoop
        );
        assert!(stays_nested_loop);
    }

    /// When the actual sizes invert, the build side moves to the smaller input.
    #[test]
    fn test_hash_join_build_side_optimization() {
        let planned = JoinAlgorithm::HashJoin {
            build_side: BuildSide::Right,
            build_rows: 100,
            probe_rows: 1000,
        };
        let ctx = JoinAqeContext::new(1000, 100, 500, 5000, true, planned);
        let expected = AqeJoinDecision::SwitchToHashJoin {
            build_side: BuildSide::Left,
        };
        assert_eq!(decide_join_algorithm(&ctx), expected);
    }
}