use crate::error::DbxResult;
use crate::sql::planner::types::{AggregateMode, PhysicalPlan};
/// One executable stage of a distributed plan: a batch of physical plans
/// that run on workers before the coordinator consumes their output.
pub struct FragmentStage {
/// 1-based identifier for this stage within the DAG.
pub stage_id: usize,
/// Worker-side plan fragments belonging to this stage.
pub plans: Vec<PhysicalPlan>,
}
/// Result of splitting a physical plan for distributed execution.
pub struct FragmentDAG {
/// Plan the coordinator runs over worker output; `None` when the plan
/// could not be split and runs entirely as a single worker stage.
pub coordinator_plan: Option<PhysicalPlan>,
/// Worker stages feeding the coordinator (at least one).
pub stages: Vec<FragmentStage>,
}
/// Stateless namespace for the plan-splitting entry points below.
pub struct FragmentSplitter;
impl FragmentSplitter {
    /// Splits a physical plan into a coordinator plan plus worker stages.
    ///
    /// When no split point is found, the returned DAG has
    /// `coordinator_plan: None` and the unchanged plan as its only stage.
    ///
    /// # Errors
    /// Propagates any error raised while recursively splitting the plan.
    pub fn split(plan: PhysicalPlan) -> DbxResult<FragmentDAG> {
        match Self::try_split(plan)? {
            // Multi-stage split: stages already carry their ids.
            SplitResult::SplitDAG {
                coordinator,
                stages,
            } => Ok(FragmentDAG {
                coordinator_plan: Some(coordinator),
                stages,
            }),
            // Single worker fragment: wrap it as stage 1.
            SplitResult::Split {
                coordinator,
                worker,
            } => Ok(FragmentDAG {
                coordinator_plan: Some(coordinator),
                stages: vec![FragmentStage {
                    stage_id: 1,
                    plans: vec![worker],
                }],
            }),
            // No split point: run the whole plan as one worker stage.
            SplitResult::Unsplit(plan) => Ok(FragmentDAG {
                coordinator_plan: None,
                stages: vec![FragmentStage {
                    stage_id: 1,
                    plans: vec![plan],
                }],
            }),
        }
    }

    /// Recursively searches `plan` for a distribution boundary.
    ///
    /// Split points are `HashAggregate` in `Final` mode (workers run the
    /// input and shuffle; the coordinator finalizes over a `GridExchange`)
    /// and `HashJoin` (each side becomes its own shuffle stage). Unary
    /// pass-through nodes (`Projection`, `Limit`, `SortMerge`) are kept on
    /// the coordinator side when their input splits. Everything else is
    /// returned unsplit.
    ///
    /// NOTE(review): exchange ids (1/2) and stage ids are hardcoded, which
    /// assumes at most one split level per plan — nested splits would
    /// collide. Confirm this single-level assumption against callers.
    fn try_split(plan: PhysicalPlan) -> DbxResult<SplitResult> {
        match plan {
            PhysicalPlan::HashAggregate {
                input,
                group_by,
                aggregates,
                mode: AggregateMode::Final,
            } => {
                // Workers execute the aggregate's input (typically a
                // Partial aggregate) and shuffle results to exchange 1.
                let worker_plan = *input;
                let coord_plan = PhysicalPlan::HashAggregate {
                    input: Box::new(PhysicalPlan::GridExchange {
                        exchange_id: 1,
                        schema_hint: extract_output_columns(&worker_plan),
                    }),
                    group_by,
                    aggregates,
                    mode: AggregateMode::Final,
                };
                let worker_shuffle = PhysicalPlan::ShuffleWriter {
                    input: Box::new(worker_plan),
                    hash_params: vec![],
                    target_nodes: vec![],
                    exchange_id: 1,
                    salting: crate::sql::planner::types::ShuffleSalting::None,
                };
                Ok(SplitResult::SplitDAG {
                    coordinator: coord_plan,
                    stages: vec![FragmentStage {
                        stage_id: 1,
                        plans: vec![worker_shuffle],
                    }],
                })
            }
            PhysicalPlan::HashJoin {
                left,
                right,
                join_type,
                on,
            } => {
                // Each join side becomes its own shuffle stage; the
                // coordinator joins the two exchanges.
                let left_worker = *left;
                let right_worker = *right;
                let coord_plan = PhysicalPlan::HashJoin {
                    left: Box::new(PhysicalPlan::GridExchange {
                        exchange_id: 1,
                        schema_hint: extract_output_columns(&left_worker),
                    }),
                    right: Box::new(PhysicalPlan::GridExchange {
                        exchange_id: 2,
                        schema_hint: extract_output_columns(&right_worker),
                    }),
                    join_type,
                    on,
                };
                let left_shuffle = PhysicalPlan::ShuffleWriter {
                    input: Box::new(left_worker),
                    hash_params: vec![],
                    target_nodes: vec![],
                    exchange_id: 1,
                    salting: crate::sql::planner::types::ShuffleSalting::None,
                };
                let right_shuffle = PhysicalPlan::ShuffleWriter {
                    input: Box::new(right_worker),
                    hash_params: vec![],
                    target_nodes: vec![],
                    exchange_id: 2,
                    salting: crate::sql::planner::types::ShuffleSalting::None,
                };
                Ok(SplitResult::SplitDAG {
                    coordinator: coord_plan,
                    stages: vec![
                        FragmentStage {
                            stage_id: 1,
                            plans: vec![left_shuffle],
                        },
                        FragmentStage {
                            stage_id: 2,
                            plans: vec![right_shuffle],
                        },
                    ],
                })
            }
            // Pass-through nodes: recurse into the input and keep this
            // node on whichever side the recursion produced.
            PhysicalPlan::Projection {
                input,
                exprs,
                aliases,
            } => match Self::try_split(*input)? {
                SplitResult::SplitDAG {
                    coordinator,
                    stages,
                } => Ok(SplitResult::SplitDAG {
                    coordinator: PhysicalPlan::Projection {
                        input: Box::new(coordinator),
                        exprs,
                        aliases,
                    },
                    stages,
                }),
                SplitResult::Split {
                    coordinator,
                    worker,
                } => Ok(SplitResult::Split {
                    coordinator: PhysicalPlan::Projection {
                        input: Box::new(coordinator),
                        // Moved, not cloned: only one match arm runs, so
                        // the values are consumed exactly once.
                        exprs,
                        aliases,
                    },
                    worker,
                }),
                SplitResult::Unsplit(unchanged) => {
                    Ok(SplitResult::Unsplit(PhysicalPlan::Projection {
                        input: Box::new(unchanged),
                        exprs,
                        aliases,
                    }))
                }
            },
            PhysicalPlan::Limit {
                input,
                count,
                offset,
            } => match Self::try_split(*input)? {
                SplitResult::SplitDAG {
                    coordinator,
                    stages,
                } => Ok(SplitResult::SplitDAG {
                    coordinator: PhysicalPlan::Limit {
                        input: Box::new(coordinator),
                        count,
                        offset,
                    },
                    stages,
                }),
                SplitResult::Split {
                    coordinator,
                    worker,
                } => Ok(SplitResult::Split {
                    coordinator: PhysicalPlan::Limit {
                        input: Box::new(coordinator),
                        count,
                        offset,
                    },
                    worker,
                }),
                SplitResult::Unsplit(unchanged) => Ok(SplitResult::Unsplit(PhysicalPlan::Limit {
                    input: Box::new(unchanged),
                    count,
                    offset,
                })),
            },
            PhysicalPlan::SortMerge { input, order_by } => match Self::try_split(*input)? {
                SplitResult::SplitDAG {
                    coordinator,
                    stages,
                } => Ok(SplitResult::SplitDAG {
                    coordinator: PhysicalPlan::SortMerge {
                        input: Box::new(coordinator),
                        // Moved, not cloned: each arm is exclusive.
                        order_by,
                    },
                    stages,
                }),
                SplitResult::Split {
                    coordinator,
                    worker,
                } => Ok(SplitResult::Split {
                    coordinator: PhysicalPlan::SortMerge {
                        input: Box::new(coordinator),
                        order_by,
                    },
                    worker,
                }),
                SplitResult::Unsplit(unchanged) => {
                    Ok(SplitResult::Unsplit(PhysicalPlan::SortMerge {
                        input: Box::new(unchanged),
                        order_by,
                    }))
                }
            },
            // Anything else (scans, filters, partial aggregates, …) is a
            // leaf from the splitter's point of view.
            other => Ok(SplitResult::Unsplit(other)),
        }
    }
}
/// Internal outcome of `FragmentSplitter::try_split`.
enum SplitResult {
/// Plan split into a coordinator plan plus one or more worker stages.
SplitDAG {
coordinator: PhysicalPlan,
stages: Vec<FragmentStage>,
},
// NOTE(review): `try_split` never constructs this variant directly (only
// propagates it from recursive calls, which cannot produce it either),
// hence the dead_code allow. Consider removing the variant or confirming
// a future producer is planned.
#[allow(dead_code)]
/// Plan split into a coordinator plan plus a single worker fragment.
Split {
coordinator: PhysicalPlan,
worker: PhysicalPlan,
},
/// No split point found; the plan is returned unchanged.
Unsplit(PhysicalPlan),
}
/// Best-effort estimate of `plan`'s output column count, used as the
/// `schema_hint` on `GridExchange` nodes.
///
/// Only plan nodes that expose their column lists are counted exactly;
/// everything else falls back to a named default width.
fn extract_output_columns(plan: &PhysicalPlan) -> usize {
    // Fallback widths for nodes whose output schema is not visible here.
    // NOTE(review): 8 and 4 look like heuristic guesses — confirm the
    // consumers of `schema_hint` tolerate an over/under-estimate.
    const DEFAULT_SCAN_WIDTH: usize = 8;
    const DEFAULT_PLAN_WIDTH: usize = 4;
    match plan {
        PhysicalPlan::HashAggregate {
            group_by,
            aggregates,
            ..
        } => {
            // One output column per group key plus one per aggregate expr.
            group_by.len() + aggregates.len()
        }
        PhysicalPlan::Projection { exprs, .. } => exprs.len(),
        PhysicalPlan::TableScan {
            projection,
            ros_files: _,
            ..
        } => {
            // An empty projection selects all columns, whose count is
            // unknown at this point — use the scan fallback.
            if projection.is_empty() {
                DEFAULT_SCAN_WIDTH
            } else {
                projection.len()
            }
        }
        _ => DEFAULT_PLAN_WIDTH,
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::sql::planner::types::{
AggregateFunction, AggregateMode, PhysicalAggExpr, PhysicalPlan,
};
/// Builds a Partial-mode SUM aggregate over a bare table scan, the
/// canonical worker-side half of a two-phase aggregation.
fn make_partial_agg() -> PhysicalPlan {
PhysicalPlan::HashAggregate {
input: Box::new(PhysicalPlan::TableScan {
table: "sales".to_string(),
projection: vec![],
filter: None,
ros_files: vec![],
}),
group_by: vec![0],
aggregates: vec![PhysicalAggExpr {
function: AggregateFunction::Sum,
input: 1,
alias: Some("partial_sum".to_string()),
}],
mode: AggregateMode::Partial,
}
}
/// Final-over-Partial aggregate must split: the coordinator keeps the
/// Final aggregate reading from a GridExchange, and a single worker
/// stage runs the Partial aggregate under a ShuffleWriter.
#[test]
fn test_split_final_over_partial_agg() {
let plan = PhysicalPlan::HashAggregate {
input: Box::new(make_partial_agg()),
group_by: vec![0],
aggregates: vec![PhysicalAggExpr {
function: AggregateFunction::Sum,
input: 1,
alias: Some("total_sum".to_string()),
}],
mode: AggregateMode::Final,
};
let dag = FragmentSplitter::split(plan).unwrap();
assert!(
dag.coordinator_plan.is_some(),
"coordinator_plan should be Some"
);
let coord = dag.coordinator_plan.unwrap();
// Coordinator side retains the Final-mode aggregate.
assert!(matches!(
coord,
PhysicalPlan::HashAggregate {
mode: AggregateMode::Final,
..
}
));
// Its input must have been rewritten to read from the exchange.
if let PhysicalPlan::HashAggregate { input, .. } = &coord {
assert!(
matches!(**input, PhysicalPlan::GridExchange { .. }),
"coordinator input should be GridExchange"
);
}
assert_eq!(dag.stages.len(), 1, "Should have 1 stage for aggregation");
// Worker stage wraps the original Partial aggregate in a shuffle.
let worker_plan = &dag.stages[0].plans[0];
if let PhysicalPlan::ShuffleWriter { input, .. } = worker_plan {
assert!(matches!(
**input,
PhysicalPlan::HashAggregate {
mode: AggregateMode::Partial,
..
}
));
} else {
panic!("Expected ShuffleWriter");
}
}
/// A bare table scan has no split point: no coordinator plan, and the
/// scan runs unchanged as the only plan of the only stage.
#[test]
fn test_no_split_simple_scan() {
let plan = PhysicalPlan::TableScan {
table: "T1".to_string(),
projection: vec![],
filter: None,
ros_files: vec![],
};
let dag = FragmentSplitter::split(plan).unwrap();
assert!(
dag.coordinator_plan.is_none(),
"simple scan should not split"
);
assert_eq!(dag.stages.len(), 1);
assert_eq!(dag.stages[0].plans.len(), 1);
}
}