1use super::get_distributed_user_codecs;
2use crate::NetworkShuffleExec;
3use crate::common::{deserialize_uuid, serialize_uuid};
4use crate::execution_plans::{
5 BroadcastExec, ChildWeight, ChildrenIsolatorUnionExec, NetworkBroadcastExec,
6 NetworkCoalesceExec,
7};
8use crate::stage::{LocalStage, RemoteStage, Stage};
9use crate::worker::WorkerConnectionPool;
10use crate::{DistributedTaskContext, NetworkBoundary};
11use bytes::Bytes;
12use datafusion::arrow::datatypes::Schema;
13use datafusion::arrow::datatypes::SchemaRef;
14use datafusion::common::Result;
15use datafusion::error::DataFusionError;
16use datafusion::execution::TaskContext;
17use datafusion::physical_expr::EquivalenceProperties;
18use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
19use datafusion::physical_plan::union::UnionExec;
20use datafusion::physical_plan::{ExecutionPlan, Partitioning, PlanProperties};
21use datafusion::prelude::SessionConfig;
22use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
23use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
24use datafusion_proto::physical_plan::{
25 ComposedPhysicalExtensionCodec, DefaultPhysicalProtoConverter, PhysicalExtensionCodec,
26 PhysicalPlanDecodeContext,
27};
28use datafusion_proto::protobuf;
29use datafusion_proto::protobuf::proto_error;
30use itertools::Itertools;
31use prost::Message;
32use std::sync::Arc;
33use url::Url;
34
35#[derive(Debug)]
38pub struct DistributedCodec;
39
40impl DistributedCodec {
41 pub fn new_combined_with_user(cfg: &SessionConfig) -> ComposedPhysicalExtensionCodec {
42 let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> = vec![Arc::new(DistributedCodec {})];
43 codecs.extend(get_distributed_user_codecs(cfg));
44 ComposedPhysicalExtensionCodec::new(codecs)
45 }
46}
47
48impl PhysicalExtensionCodec for DistributedCodec {
49 fn try_decode(
50 &self,
51 buf: &[u8],
52 inputs: &[Arc<dyn ExecutionPlan>],
53 ctx: &TaskContext,
54 ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
55 let DistributedExecProto {
56 node: Some(distributed_exec_node),
57 } = DistributedExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?
58 else {
59 return Err(proto_error(
60 "Expected DistributedExecNode in DistributedExecProto",
61 ));
62 };
63
64 fn parse_stage_proto(
65 proto: Option<StageProto>,
66 inputs: &[Arc<dyn ExecutionPlan>],
67 ) -> Result<Stage, DataFusionError> {
68 let Some(proto) = proto else {
69 return Err(proto_error("Empty StageProto"));
70 };
71 if let Some(input) = inputs.first().cloned() {
72 Ok(Stage::Local(LocalStage {
73 query_id: deserialize_uuid(proto.query_id.as_ref())?,
74 num: proto.num as usize,
75 plan: input,
76 tasks: proto.tasks.len(),
77 }))
78 } else {
79 let mut worker_urls = Vec::with_capacity(proto.tasks.len());
80 for task in proto.tasks {
81 let Some(url_str) = task.url_str else {
82 return Err(proto_error("Missing URL in task"));
83 };
84 let Ok(url) = Url::parse(&url_str) else {
85 return Err(proto_error("Invalid URL in task"));
86 };
87 worker_urls.push(url);
88 }
89 Ok(Stage::Remote(RemoteStage {
90 query_id: deserialize_uuid(proto.query_id.as_ref())?,
91 num: proto.num as usize,
92 workers: worker_urls,
93 }))
94 }
95 }
96
97 match distributed_exec_node {
98 DistributedExecNode::NetworkHashShuffle(NetworkShuffleExecProto {
99 schema,
100 partitioning,
101 input_stage,
102 }) => {
103 let schema: Schema = schema
104 .as_ref()
105 .map(|s| s.try_into())
106 .ok_or(proto_error("NetworkShuffleExec is missing schema"))??;
107
108 let decode_ctx = PhysicalPlanDecodeContext::new(ctx, &DistributedCodec {});
109 let partitioning = parse_protobuf_partitioning(
110 partitioning.as_ref(),
111 &decode_ctx,
112 &schema,
113 &DefaultPhysicalProtoConverter {},
114 )?
115 .ok_or(proto_error("NetworkShuffleExec is missing partitioning"))?;
116
117 Ok(Arc::new(new_network_hash_shuffle_exec(
118 partitioning,
119 Arc::new(schema),
120 parse_stage_proto(input_stage, inputs)?,
121 )))
122 }
123 DistributedExecNode::NetworkCoalesceTasks(NetworkCoalesceExecProto {
124 schema,
125 partitioning,
126 input_stage,
127 }) => {
128 let schema: Schema = schema
129 .as_ref()
130 .map(|s| s.try_into())
131 .ok_or(proto_error("NetworkCoalesceExec is missing schema"))??;
132
133 let decode_ctx = PhysicalPlanDecodeContext::new(ctx, &DistributedCodec {});
134 let partitioning = parse_protobuf_partitioning(
135 partitioning.as_ref(),
136 &decode_ctx,
137 &schema,
138 &DefaultPhysicalProtoConverter {},
139 )?
140 .ok_or(proto_error("NetworkCoalesceExec is missing partitioning"))?;
141
142 Ok(Arc::new(new_network_coalesce_tasks_exec(
143 partitioning,
144 Arc::new(schema),
145 parse_stage_proto(input_stage, inputs)?,
146 )))
147 }
148 DistributedExecNode::NetworkBroadcast(NetworkBroadcastExecProto {
149 schema,
150 partitioning,
151 input_stage,
152 }) => {
153 let schema: Schema = schema
154 .as_ref()
155 .map(|s| s.try_into())
156 .ok_or(proto_error("NetworkBroadcastExec is missing schema"))??;
157
158 let decode_ctx = PhysicalPlanDecodeContext::new(ctx, &DistributedCodec {});
159 let partitioning = parse_protobuf_partitioning(
160 partitioning.as_ref(),
161 &decode_ctx,
162 &schema,
163 &DefaultPhysicalProtoConverter {},
164 )?
165 .ok_or(proto_error("NetworkBroadcastExec is missing partitioning"))?;
166
167 Ok(Arc::new(new_network_broadcast_exec(
168 partitioning,
169 Arc::new(schema),
170 parse_stage_proto(input_stage, inputs)?,
171 )))
172 }
173 DistributedExecNode::Broadcast(BroadcastExecProto {
174 consumer_task_count,
175 }) => {
176 if inputs.len() != 1 {
177 return Err(proto_error(format!(
178 "BroadcastExec expects exactly one child, got {}",
179 inputs.len()
180 )));
181 }
182
183 let child = inputs.first().unwrap();
184 Ok(Arc::new(BroadcastExec::new(
185 child.clone(),
186 consumer_task_count as usize,
187 )))
188 }
189 DistributedExecNode::ChildrenIsolatorUnion(ChildrenIsolatorUnionExecProto {
190 partition_count,
191 task_idx_map,
192 child_weights,
193 }) => {
194 let mut properties = UnionExec::try_new(inputs.to_vec())?
199 .properties()
200 .as_ref()
201 .clone();
202 properties.partitioning =
203 Partitioning::UnknownPartitioning(partition_count as usize);
204
205 Ok(Arc::new(ChildrenIsolatorUnionExec {
206 properties: Arc::new(properties),
207 metrics: Default::default(),
208 children: inputs.to_vec(),
209 child_weights: child_weights
210 .iter()
211 .map(|cw| ChildWeight {
212 weight: cw.weight,
213 max: cw.max.map(|m| m as usize),
214 })
215 .collect(),
216 task_idx_map: task_idx_map
217 .iter()
218 .map(|entry| {
219 entry
220 .child_ctx
221 .iter()
222 .map(|child_ctx| {
223 (
224 child_ctx.child_idx as usize,
225 DistributedTaskContext {
226 task_index: child_ctx.task_idx as usize,
227 task_count: child_ctx.task_count as usize,
228 },
229 )
230 })
231 .collect_vec()
232 })
233 .collect(),
234 }))
235 }
236 }
237 }
238
239 fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
240 fn encode_stage_proto(stage: &Stage) -> Result<StageProto, DataFusionError> {
241 Ok(match stage {
242 Stage::Local(local) => StageProto {
243 query_id: serialize_uuid(&local.query_id).into(),
244 num: local.num as u64,
245 tasks: vec![ExecutionTaskProto::default(); local.tasks],
246 },
247 Stage::Remote(remote) => {
248 let mut tasks = Vec::with_capacity(remote.workers.len());
249 for worker in &remote.workers {
250 tasks.push(ExecutionTaskProto {
251 url_str: Some(worker.to_string()),
252 })
253 }
254 StageProto {
255 query_id: serialize_uuid(&remote.query_id).into(),
256 num: remote.num as u64,
257 tasks,
258 }
259 }
260 })
261 }
262
263 if let Some(node) = node.downcast_ref::<NetworkShuffleExec>() {
264 let inner = NetworkShuffleExecProto {
265 schema: Some(node.schema().try_into()?),
266 partitioning: Some(serialize_partitioning(
267 node.properties().output_partitioning(),
268 &DistributedCodec {},
269 &DefaultPhysicalProtoConverter {},
270 )?),
271 input_stage: Some(encode_stage_proto(node.input_stage())?),
272 };
273
274 let wrapper = DistributedExecProto {
275 node: Some(DistributedExecNode::NetworkHashShuffle(inner)),
276 };
277
278 wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
279 } else if let Some(node) = node.downcast_ref::<NetworkCoalesceExec>() {
280 let inner = NetworkCoalesceExecProto {
281 schema: Some(node.schema().try_into()?),
282 partitioning: Some(serialize_partitioning(
283 node.properties().output_partitioning(),
284 &DistributedCodec {},
285 &DefaultPhysicalProtoConverter {},
286 )?),
287 input_stage: Some(encode_stage_proto(node.input_stage())?),
288 };
289
290 let wrapper = DistributedExecProto {
291 node: Some(DistributedExecNode::NetworkCoalesceTasks(inner)),
292 };
293
294 wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
295 } else if let Some(node) = node.downcast_ref::<NetworkBroadcastExec>() {
296 let inner = NetworkBroadcastExecProto {
297 schema: Some(node.schema().try_into()?),
298 partitioning: Some(serialize_partitioning(
299 node.properties().output_partitioning(),
300 &DistributedCodec {},
301 &DefaultPhysicalProtoConverter {},
302 )?),
303 input_stage: Some(encode_stage_proto(node.input_stage())?),
304 };
305
306 let wrapper = DistributedExecProto {
307 node: Some(DistributedExecNode::NetworkBroadcast(inner)),
308 };
309
310 wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
311 } else if let Some(node) = node.downcast_ref::<BroadcastExec>() {
312 let inner = BroadcastExecProto {
313 consumer_task_count: node.consumer_task_count() as u64,
314 };
315
316 let wrapper = DistributedExecProto {
317 node: Some(DistributedExecNode::Broadcast(inner)),
318 };
319
320 wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
321 } else if let Some(node) = node.downcast_ref::<ChildrenIsolatorUnionExec>() {
322 let inner = ChildrenIsolatorUnionExecProto {
323 partition_count: node.properties().output_partitioning().partition_count() as u64,
324 task_idx_map: node
325 .task_idx_map
326 .iter()
327 .map(|v| TaskIdxMapEntryProto {
328 child_ctx: v
329 .iter()
330 .map(|(child_idx, task_ctx)| ChildIdxWithTaskContextProto {
331 child_idx: *child_idx as u64,
332 task_idx: task_ctx.task_index as u64,
333 task_count: task_ctx.task_count as u64,
334 })
335 .collect_vec(),
336 })
337 .collect_vec(),
338 child_weights: node
339 .child_weights
340 .iter()
341 .map(|cw| ChildWeightProto {
342 weight: cw.weight,
343 max: cw.max.map(|m| m as u64),
344 })
345 .collect_vec(),
346 };
347
348 let wrapper = DistributedExecProto {
349 node: Some(DistributedExecNode::ChildrenIsolatorUnion(inner)),
350 };
351
352 wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
353 } else {
354 Err(proto_error(format!("Unexpected plan {}", node.name())))
355 }
356 }
357}
358
/// Wire form of a [`Stage`], shared by all network boundary nodes.
///
/// Encoding keeps only the task count for a [`Stage::Local`]. When decoding with a
/// non-empty input list, the stage plan is taken from the node's first decoded input.
/// A [`Stage::Remote`] keeps one worker URL per task.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct StageProto {
    /// Query id, written with `serialize_uuid` and read back with `deserialize_uuid`.
    #[prost(bytes, tag = "1")]
    pub query_id: Bytes,
    /// Stage number (the `num` field of the local/remote stage).
    #[prost(uint64, tag = "2")]
    pub num: u64,
    /// One entry per task. For remote stages each entry carries the worker URL; for
    /// local stages the entries are default (URL unset) and only their count is used.
    #[prost(message, repeated, tag = "3")]
    pub tasks: Vec<ExecutionTaskProto>,
}

/// Wire form of a single task within a [`StageProto`].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ExecutionTaskProto {
    /// Worker URL the task is addressed at. Set when encoding a remote stage, unset for a
    /// local one. When decoding without inputs, a missing or unparseable URL is an error.
    #[prost(string, optional, tag = "1")]
    pub url_str: Option<String>,
}
380
381#[derive(Clone, PartialEq, ::prost::Message)]
382pub struct DistributedExecProto {
383 #[prost(oneof = "DistributedExecNode", tags = "1, 2, 3, 4, 5, 6")]
384 pub node: Option<DistributedExecNode>,
385}
386
387#[derive(Clone, PartialEq, prost::Oneof)]
388pub enum DistributedExecNode {
389 #[prost(message, tag = "1")]
390 NetworkHashShuffle(NetworkShuffleExecProto),
391 #[prost(message, tag = "2")]
392 NetworkCoalesceTasks(NetworkCoalesceExecProto),
393 #[prost(message, tag = "4")]
395 ChildrenIsolatorUnion(ChildrenIsolatorUnionExecProto),
396 #[prost(message, tag = "5")]
397 NetworkBroadcast(NetworkBroadcastExecProto),
398 #[prost(message, tag = "6")]
399 Broadcast(BroadcastExecProto),
400}
401
/// Wire form of [`NetworkShuffleExec`].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NetworkShuffleExecProto {
    /// Output schema; decoding fails with "NetworkShuffleExec is missing schema" if unset.
    #[prost(message, optional, tag = "1")]
    schema: Option<protobuf::Schema>,
    /// Output partitioning, serialized via `serialize_partitioning`.
    #[prost(message, optional, tag = "2")]
    partitioning: Option<protobuf::Partitioning>,
    /// Stage this node reads from; see [`StageProto`].
    #[prost(message, optional, tag = "3")]
    input_stage: Option<StageProto>,
}
414
/// Wire form of [`ChildrenIsolatorUnionExec`].
///
/// The children themselves are not stored here; on decode they are taken from the
/// node's decoded inputs.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ChildrenIsolatorUnionExecProto {
    /// Output partition count; decoded as `Partitioning::UnknownPartitioning`.
    #[prost(uint64, tag = "1")]
    partition_count: u64,
    /// Mirrors `ChildrenIsolatorUnionExec::task_idx_map`, one proto entry per element.
    /// NOTE(review): presumably indexed by task index — confirm.
    #[prost(message, repeated, tag = "2")]
    task_idx_map: Vec<TaskIdxMapEntryProto>,
    /// Mirrors `ChildrenIsolatorUnionExec::child_weights`.
    #[prost(message, repeated, tag = "3")]
    child_weights: Vec<ChildWeightProto>,
}

/// Wire form of a [`ChildWeight`].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ChildWeightProto {
    /// `ChildWeight::weight`.
    #[prost(double, tag = "1")]
    weight: f64,
    /// `ChildWeight::max`, if set.
    #[prost(uint64, optional, tag = "2")]
    max: Option<u64>,
}

/// One element of `ChildrenIsolatorUnionExec::task_idx_map`.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TaskIdxMapEntryProto {
    /// The `(child index, task context)` pairs of this element.
    #[prost(message, repeated, tag = "1")]
    child_ctx: Vec<ChildIdxWithTaskContextProto>,
}

/// A `(child index, DistributedTaskContext)` pair from `task_idx_map`.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ChildIdxWithTaskContextProto {
    /// Child index (first element of the pair).
    #[prost(uint64, tag = "1")]
    child_idx: u64,
    /// `DistributedTaskContext::task_index`.
    #[prost(uint64, tag = "2")]
    task_idx: u64,
    /// `DistributedTaskContext::task_count`.
    #[prost(uint64, tag = "3")]
    task_count: u64,
}
448
449fn new_network_hash_shuffle_exec(
450 partitioning: Partitioning,
451 schema: SchemaRef,
452 input_stage: Stage,
453) -> NetworkShuffleExec {
454 NetworkShuffleExec {
455 properties: Arc::new(PlanProperties::new(
456 EquivalenceProperties::new(schema),
457 partitioning,
458 EmissionType::Incremental,
459 Boundedness::Bounded,
460 )),
461 worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
462 input_stage,
463 }
464}
465
/// Wire form of [`NetworkCoalesceExec`].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NetworkCoalesceExecProto {
    /// Output schema; decoding fails with "NetworkCoalesceExec is missing schema" if unset.
    #[prost(message, optional, tag = "1")]
    schema: Option<protobuf::Schema>,
    /// Output partitioning, serialized via `serialize_partitioning`.
    #[prost(message, optional, tag = "2")]
    partitioning: Option<protobuf::Partitioning>,
    /// Stage this node reads from; see [`StageProto`].
    #[prost(message, optional, tag = "3")]
    input_stage: Option<StageProto>,
}
478
479fn new_network_coalesce_tasks_exec(
480 partitioning: Partitioning,
481 schema: SchemaRef,
482 input_stage: Stage,
483) -> NetworkCoalesceExec {
484 NetworkCoalesceExec {
485 properties: Arc::new(PlanProperties::new(
486 EquivalenceProperties::new(schema),
487 partitioning,
488 EmissionType::Incremental,
489 Boundedness::Bounded,
490 )),
491 worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
492 input_stage,
493 }
494}
495
/// Wire form of [`NetworkBroadcastExec`].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NetworkBroadcastExecProto {
    /// Output schema; decoding fails with "NetworkBroadcastExec is missing schema" if unset.
    #[prost(message, optional, tag = "1")]
    schema: Option<protobuf::Schema>,
    /// Output partitioning, serialized via `serialize_partitioning`.
    #[prost(message, optional, tag = "2")]
    partitioning: Option<protobuf::Partitioning>,
    /// Stage this node reads from; see [`StageProto`].
    #[prost(message, optional, tag = "3")]
    input_stage: Option<StageProto>,
}
505
/// Wire form of [`BroadcastExec`]. Its single child is not stored here; on decode it is
/// taken from the node's decoded inputs, which must contain exactly one plan.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BroadcastExecProto {
    /// Value of `BroadcastExec::consumer_task_count`, passed back to `BroadcastExec::new`.
    #[prost(uint64, tag = "1")]
    pub consumer_task_count: u64,
}
511
512fn new_network_broadcast_exec(
513 partitioning: Partitioning,
514 schema: SchemaRef,
515 input_stage: Stage,
516) -> NetworkBroadcastExec {
517 NetworkBroadcastExec {
518 properties: Arc::new(PlanProperties::new(
519 EquivalenceProperties::new(schema),
520 partitioning,
521 EmissionType::Incremental,
522 Boundedness::Bounded,
523 )),
524 worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
525 input_stage,
526 }
527}
528
#[cfg(test)]
mod tests {
    //! Encode/decode roundtrip tests for every node kind handled by `DistributedCodec`,
    //! plus the decode error paths.
    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field};
    use datafusion::physical_expr::LexOrdering;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::prelude::SessionContext;
    use datafusion::{
        physical_expr::{Partitioning, PhysicalSortExpr, expressions::Column, expressions::col},
        physical_plan::{ExecutionPlan, displayable, sorts::sort::SortExec, union::UnionExec},
    };

    fn empty_exec() -> Arc<dyn ExecutionPlan> {
        Arc::new(EmptyExec::new(SchemaRef::new(Schema::empty())))
    }

    fn dummy_stage() -> Stage {
        Stage::Remote(RemoteStage {
            query_id: Default::default(),
            num: 0,
            workers: vec![],
        })
    }

    fn dummy_stage_with_plan() -> Stage {
        Stage::Local(LocalStage {
            query_id: Default::default(),
            num: 0,
            plan: empty_exec(),
            tasks: 1,
        })
    }

    fn schema_i32(name: &str) -> Arc<Schema> {
        Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)]))
    }

    fn repr(plan: &Arc<dyn ExecutionPlan>) -> String {
        displayable(plan.as_ref()).indent(true).to_string()
    }

    fn create_context() -> Arc<TaskContext> {
        SessionContext::new().task_ctx()
    }

    #[test]
    fn test_roundtrip_single_flight() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("a");
        let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(new_network_hash_shuffle_exec(part, schema, dummy_stage()));

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    #[test]
    fn test_roundtrip_union() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("c");
        let left = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            dummy_stage(),
        ));
        let right = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            dummy_stage(),
        ));

        let union = UnionExec::try_new(vec![left.clone(), right.clone()])?;
        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(NetworkCoalesceExec::try_new(union.clone(), 1, 1)?);

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[union], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    #[test]
    fn test_roundtrip_sort_flight() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("d");
        let flight = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::UnknownPartitioning(1),
            schema.clone(),
            dummy_stage(),
        ));

        let sort_expr = PhysicalSortExpr {
            expr: col("d", &schema)?,
            options: Default::default(),
        };
        let sort = Arc::new(SortExec::new(
            LexOrdering::new(vec![sort_expr]).unwrap(),
            flight.clone(),
        ));

        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(NetworkCoalesceExec::try_new(sort.clone(), 1, 1)?);

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[sort], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    #[test]
    fn test_roundtrip_single_flight_coalesce() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("e");
        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::RoundRobinBatch(3),
            schema,
            dummy_stage(),
        ));

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    #[test]
    fn test_roundtrip_single_flight_with_plan() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("a");
        let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_hash_shuffle_exec(
            part,
            schema,
            dummy_stage_with_plan(),
        ));

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    #[test]
    fn test_roundtrip_single_flight_coalesce_with_plan() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("e");
        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::RoundRobinBatch(3),
            schema,
            dummy_stage_with_plan(),
        ));

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    #[test]
    fn test_roundtrip_flight_coalesce() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("f");
        let flight = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::UnknownPartitioning(1),
            schema,
            dummy_stage(),
        ));

        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(NetworkCoalesceExec::try_new(flight.clone(), 1, 1)?);

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[flight], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    #[test]
    fn test_roundtrip_union_coalesce() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("g");
        let left = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            dummy_stage(),
        ));
        let right = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            dummy_stage(),
        ));

        let union = UnionExec::try_new(vec![left.clone(), right.clone()])?;
        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(NetworkCoalesceExec::try_new(union.clone(), 1, 1)?);

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[union], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    #[test]
    fn test_roundtrip_children_isolator_union() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("h");
        let left = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            dummy_stage(),
        )) as Arc<dyn ExecutionPlan>;
        let right = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            dummy_stage(),
        )) as Arc<dyn ExecutionPlan>;

        let plan: Arc<dyn ExecutionPlan> =
            Arc::new(ChildrenIsolatorUnionExec::from_children_and_weights(
                vec![left.clone(), right.clone()],
                vec![ChildWeight::desired(3.0), ChildWeight::maximum(1)],
                4,
            )?);

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[left, right], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    // NetworkBroadcastExec was encodable/decodable but had no roundtrip coverage.
    #[test]
    fn test_roundtrip_network_broadcast() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let schema = schema_i32("i");
        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_broadcast_exec(
            Partitioning::UnknownPartitioning(2),
            schema,
            dummy_stage(),
        ));

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    // BroadcastExec was encodable/decodable but had no roundtrip coverage.
    #[test]
    fn test_roundtrip_broadcast() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let plan: Arc<dyn ExecutionPlan> = Arc::new(BroadcastExec::new(empty_exec(), 3));

        let mut buf = Vec::new();
        codec.try_encode(plan.clone(), &mut buf)?;

        let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?;
        assert_eq!(repr(&plan), repr(&decoded));

        Ok(())
    }

    // Decoding a BroadcastExec without its single child must be a clean error.
    #[test]
    fn test_decode_broadcast_rejects_wrong_child_count() -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = create_context();

        let plan: Arc<dyn ExecutionPlan> = Arc::new(BroadcastExec::new(empty_exec(), 3));
        let mut buf = Vec::new();
        codec.try_encode(plan, &mut buf)?;

        let err = codec.try_decode(&buf, &[], &ctx).unwrap_err();
        assert!(
            err.to_string()
                .contains("BroadcastExec expects exactly one child, got 0")
        );

        Ok(())
    }

    // An empty buffer decodes to an envelope with no node, which must be rejected.
    #[test]
    fn test_decode_empty_buffer_fails() {
        let codec = DistributedCodec;
        let ctx = create_context();

        let err = codec.try_decode(&[], &[], &ctx).unwrap_err();
        assert!(
            err.to_string()
                .contains("Expected DistributedExecNode in DistributedExecProto")
        );
    }
}