1use crate::error::DbxResult;
15use crate::sql::planner::types::{AggregateMode, PhysicalPlan};
16
17pub struct FragmentStage {
18 pub stage_id: usize,
19 pub plans: Vec<PhysicalPlan>,
20}
21
22pub struct FragmentDAG {
24 pub coordinator_plan: Option<PhysicalPlan>,
27 pub stages: Vec<FragmentStage>,
29}
30
/// Stateless entry point for plan fragmentation; all functionality lives in
/// associated functions such as `FragmentSplitter::split`.
pub struct FragmentSplitter;
32
33impl FragmentSplitter {
34 pub fn split(plan: PhysicalPlan) -> DbxResult<FragmentDAG> {
39 match Self::try_split(plan)? {
40 SplitResult::SplitDAG {
41 coordinator,
42 stages,
43 } => Ok(FragmentDAG {
44 coordinator_plan: Some(coordinator),
45 stages,
46 }),
47 SplitResult::Split {
48 coordinator,
49 worker,
50 } => Ok(FragmentDAG {
51 coordinator_plan: Some(coordinator),
52 stages: vec![FragmentStage {
53 stage_id: 1,
54 plans: vec![worker],
55 }],
56 }),
57 SplitResult::Unsplit(plan) => Ok(FragmentDAG {
58 coordinator_plan: None,
59 stages: vec![FragmentStage {
60 stage_id: 1,
61 plans: vec![plan],
62 }],
63 }),
64 }
65 }
66
67 fn try_split(plan: PhysicalPlan) -> DbxResult<SplitResult> {
68 match plan {
69 PhysicalPlan::HashAggregate {
74 input,
75 group_by,
76 aggregates,
77 mode: AggregateMode::Final,
78 } => {
79 let worker_plan = *input;
81
82 let coord_plan = PhysicalPlan::HashAggregate {
84 input: Box::new(PhysicalPlan::GridExchange {
85 exchange_id: 1, schema_hint: extract_output_columns(&worker_plan),
87 }),
88 group_by,
89 aggregates,
90 mode: AggregateMode::Final,
91 };
92
93 let worker_shuffle = PhysicalPlan::ShuffleWriter {
95 input: Box::new(worker_plan),
96 hash_params: vec![], target_nodes: vec![], exchange_id: 1,
99 salting: crate::sql::planner::types::ShuffleSalting::None,
100 };
101
102 Ok(SplitResult::SplitDAG {
103 coordinator: coord_plan,
104 stages: vec![FragmentStage {
105 stage_id: 1,
106 plans: vec![worker_shuffle],
107 }],
108 })
109 }
110
111 PhysicalPlan::HashJoin {
113 left,
114 right,
115 join_type,
116 on,
117 } => {
118 let left_worker = *left;
119 let right_worker = *right;
120
121 let coord_plan = PhysicalPlan::HashJoin {
123 left: Box::new(PhysicalPlan::GridExchange {
124 exchange_id: 1,
125 schema_hint: extract_output_columns(&left_worker),
126 }),
127 right: Box::new(PhysicalPlan::GridExchange {
128 exchange_id: 2,
129 schema_hint: extract_output_columns(&right_worker),
130 }),
131 join_type,
132 on,
133 };
134
135 let left_shuffle = PhysicalPlan::ShuffleWriter {
137 input: Box::new(left_worker),
138 hash_params: vec![], target_nodes: vec![],
140 exchange_id: 1,
141 salting: crate::sql::planner::types::ShuffleSalting::None,
142 };
143
144 let right_shuffle = PhysicalPlan::ShuffleWriter {
146 input: Box::new(right_worker),
147 hash_params: vec![],
148 target_nodes: vec![],
149 exchange_id: 2,
150 salting: crate::sql::planner::types::ShuffleSalting::None,
151 };
152
153 Ok(SplitResult::SplitDAG {
154 coordinator: coord_plan,
155 stages: vec![
156 FragmentStage {
157 stage_id: 1,
158 plans: vec![left_shuffle],
159 },
160 FragmentStage {
161 stage_id: 2,
162 plans: vec![right_shuffle],
163 },
164 ],
165 })
166 }
167
168 PhysicalPlan::Projection {
170 input,
171 exprs,
172 aliases,
173 } => match Self::try_split(*input)? {
174 SplitResult::SplitDAG {
175 coordinator,
176 stages,
177 } => Ok(SplitResult::SplitDAG {
178 coordinator: PhysicalPlan::Projection {
179 input: Box::new(coordinator),
180 exprs,
181 aliases,
182 },
183 stages,
184 }),
185 SplitResult::Split {
186 coordinator,
187 worker,
188 } => Ok(SplitResult::Split {
189 coordinator: PhysicalPlan::Projection {
190 input: Box::new(coordinator),
191 exprs: exprs.clone(),
192 aliases: aliases.clone(),
193 },
194 worker,
195 }),
196 SplitResult::Unsplit(unchanged) => {
197 Ok(SplitResult::Unsplit(PhysicalPlan::Projection {
198 input: Box::new(unchanged),
199 exprs,
200 aliases,
201 }))
202 }
203 },
204
205 PhysicalPlan::Limit {
206 input,
207 count,
208 offset,
209 } => match Self::try_split(*input)? {
210 SplitResult::SplitDAG {
211 coordinator,
212 stages,
213 } => Ok(SplitResult::SplitDAG {
214 coordinator: PhysicalPlan::Limit {
215 input: Box::new(coordinator),
216 count,
217 offset,
218 },
219 stages,
220 }),
221 SplitResult::Split {
222 coordinator,
223 worker,
224 } => Ok(SplitResult::Split {
225 coordinator: PhysicalPlan::Limit {
226 input: Box::new(coordinator),
227 count,
228 offset,
229 },
230 worker,
231 }),
232 SplitResult::Unsplit(unchanged) => Ok(SplitResult::Unsplit(PhysicalPlan::Limit {
233 input: Box::new(unchanged),
234 count,
235 offset,
236 })),
237 },
238
239 PhysicalPlan::SortMerge { input, order_by } => match Self::try_split(*input)? {
240 SplitResult::SplitDAG {
241 coordinator,
242 stages,
243 } => Ok(SplitResult::SplitDAG {
244 coordinator: PhysicalPlan::SortMerge {
245 input: Box::new(coordinator),
246 order_by: order_by.clone(),
247 },
248 stages,
249 }),
250 SplitResult::Split {
251 coordinator,
252 worker,
253 } => Ok(SplitResult::Split {
254 coordinator: PhysicalPlan::SortMerge {
255 input: Box::new(coordinator),
256 order_by: order_by.clone(),
257 },
258 worker,
259 }),
260 SplitResult::Unsplit(unchanged) => {
261 Ok(SplitResult::Unsplit(PhysicalPlan::SortMerge {
262 input: Box::new(unchanged),
263 order_by,
264 }))
265 }
266 },
267
268 other => Ok(SplitResult::Unsplit(other)),
270 }
271 }
272}
273
274enum SplitResult {
276 SplitDAG {
277 coordinator: PhysicalPlan,
278 stages: Vec<FragmentStage>,
279 },
280 #[allow(dead_code)]
281 Split {
282 coordinator: PhysicalPlan,
283 worker: PhysicalPlan,
284 },
285 Unsplit(PhysicalPlan),
286}
287
288fn extract_output_columns(plan: &PhysicalPlan) -> usize {
290 match plan {
291 PhysicalPlan::HashAggregate {
292 group_by,
293 aggregates,
294 ..
295 } => group_by.len() + aggregates.len(),
296 PhysicalPlan::Projection { exprs, .. } => exprs.len(),
297 PhysicalPlan::TableScan {
298 projection,
299 ros_files: _,
300 ..
301 } => {
302 if projection.is_empty() {
303 8
304 } else {
305 projection.len()
306 }
307 }
308 _ => 4, }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::planner::types::{
        AggregateFunction, AggregateMode, PhysicalAggExpr, PhysicalPlan,
    };

    /// Builds a `SUM` aggregate over input column 1 with the given alias.
    fn sum_of_column_one(alias: &str) -> PhysicalAggExpr {
        PhysicalAggExpr {
            function: AggregateFunction::Sum,
            input: 1,
            alias: Some(alias.to_string()),
        }
    }

    /// Builds a table scan with no projection, filter or files.
    fn bare_scan(table: &str) -> PhysicalPlan {
        PhysicalPlan::TableScan {
            table: table.to_string(),
            projection: vec![],
            filter: None,
            ros_files: vec![],
        }
    }

    /// Partial aggregation over a bare scan of `sales`, grouped by column 0.
    fn make_partial_agg() -> PhysicalPlan {
        PhysicalPlan::HashAggregate {
            input: Box::new(bare_scan("sales")),
            group_by: vec![0],
            aggregates: vec![sum_of_column_one("partial_sum")],
            mode: AggregateMode::Partial,
        }
    }

    /// A final aggregate over a partial one must keep the final aggregate on
    /// the coordinator (reading from an exchange) and push the partial
    /// aggregate behind a shuffle writer in a single stage.
    #[test]
    fn test_split_final_over_partial_agg() {
        let final_agg = PhysicalPlan::HashAggregate {
            input: Box::new(make_partial_agg()),
            group_by: vec![0],
            aggregates: vec![sum_of_column_one("total_sum")],
            mode: AggregateMode::Final,
        };

        let FragmentDAG {
            coordinator_plan,
            stages,
        } = FragmentSplitter::split(final_agg).unwrap();

        assert!(
            coordinator_plan.is_some(),
            "coordinator_plan should be Some"
        );
        let coord = coordinator_plan.unwrap();

        assert!(matches!(
            &coord,
            PhysicalPlan::HashAggregate {
                mode: AggregateMode::Final,
                ..
            }
        ));

        if let PhysicalPlan::HashAggregate { input, .. } = &coord {
            assert!(
                matches!(**input, PhysicalPlan::GridExchange { .. }),
                "coordinator input should be GridExchange"
            );
        }

        assert_eq!(stages.len(), 1, "Should have 1 stage for aggregation");

        match &stages[0].plans[0] {
            PhysicalPlan::ShuffleWriter { input, .. } => {
                assert!(matches!(
                    **input,
                    PhysicalPlan::HashAggregate {
                        mode: AggregateMode::Partial,
                        ..
                    }
                ));
            }
            _ => panic!("Expected ShuffleWriter"),
        }
    }

    /// A lone table scan has no split point: no coordinator plan, and the
    /// scan is the only plan of the only stage.
    #[test]
    fn test_no_split_simple_scan() {
        let dag = FragmentSplitter::split(bare_scan("T1")).unwrap();

        assert!(
            dag.coordinator_plan.is_none(),
            "simple scan should not split"
        );
        assert_eq!(dag.stages.len(), 1);
        assert_eq!(dag.stages[0].plans.len(), 1);
    }
}