Skip to main content

datafusion_distributed/protobuf/
distributed_codec.rs

1use super::get_distributed_user_codecs;
2use crate::NetworkShuffleExec;
3use crate::common::{deserialize_uuid, serialize_uuid};
4use crate::execution_plans::{
5    BroadcastExec, ChildWeight, ChildrenIsolatorUnionExec, NetworkBroadcastExec,
6    NetworkCoalesceExec,
7};
8use crate::stage::{LocalStage, RemoteStage, Stage};
9use crate::worker::WorkerConnectionPool;
10use crate::{DistributedTaskContext, NetworkBoundary};
11use bytes::Bytes;
12use datafusion::arrow::datatypes::Schema;
13use datafusion::arrow::datatypes::SchemaRef;
14use datafusion::common::Result;
15use datafusion::error::DataFusionError;
16use datafusion::execution::TaskContext;
17use datafusion::physical_expr::EquivalenceProperties;
18use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
19use datafusion::physical_plan::union::UnionExec;
20use datafusion::physical_plan::{ExecutionPlan, Partitioning, PlanProperties};
21use datafusion::prelude::SessionConfig;
22use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
23use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
24use datafusion_proto::physical_plan::{
25    ComposedPhysicalExtensionCodec, DefaultPhysicalProtoConverter, PhysicalExtensionCodec,
26    PhysicalPlanDecodeContext,
27};
28use datafusion_proto::protobuf;
29use datafusion_proto::protobuf::proto_error;
30use itertools::Itertools;
31use prost::Message;
32use std::sync::Arc;
33use url::Url;
34
35/// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and
36/// deserializing the custom ExecutionPlans in this project
37#[derive(Debug)]
38pub struct DistributedCodec;
39
40impl DistributedCodec {
41    pub fn new_combined_with_user(cfg: &SessionConfig) -> ComposedPhysicalExtensionCodec {
42        let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> = vec![Arc::new(DistributedCodec {})];
43        codecs.extend(get_distributed_user_codecs(cfg));
44        ComposedPhysicalExtensionCodec::new(codecs)
45    }
46}
47
48impl PhysicalExtensionCodec for DistributedCodec {
49    fn try_decode(
50        &self,
51        buf: &[u8],
52        inputs: &[Arc<dyn ExecutionPlan>],
53        ctx: &TaskContext,
54    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
55        let DistributedExecProto {
56            node: Some(distributed_exec_node),
57        } = DistributedExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?
58        else {
59            return Err(proto_error(
60                "Expected DistributedExecNode in DistributedExecProto",
61            ));
62        };
63
64        fn parse_stage_proto(
65            proto: Option<StageProto>,
66            inputs: &[Arc<dyn ExecutionPlan>],
67        ) -> Result<Stage, DataFusionError> {
68            let Some(proto) = proto else {
69                return Err(proto_error("Empty StageProto"));
70            };
71            if let Some(input) = inputs.first().cloned() {
72                Ok(Stage::Local(LocalStage {
73                    query_id: deserialize_uuid(proto.query_id.as_ref())?,
74                    num: proto.num as usize,
75                    plan: input,
76                    tasks: proto.tasks.len(),
77                }))
78            } else {
79                let mut worker_urls = Vec::with_capacity(proto.tasks.len());
80                for task in proto.tasks {
81                    let Some(url_str) = task.url_str else {
82                        return Err(proto_error("Missing URL in task"));
83                    };
84                    let Ok(url) = Url::parse(&url_str) else {
85                        return Err(proto_error("Invalid URL in task"));
86                    };
87                    worker_urls.push(url);
88                }
89                Ok(Stage::Remote(RemoteStage {
90                    query_id: deserialize_uuid(proto.query_id.as_ref())?,
91                    num: proto.num as usize,
92                    workers: worker_urls,
93                }))
94            }
95        }
96
97        match distributed_exec_node {
98            DistributedExecNode::NetworkHashShuffle(NetworkShuffleExecProto {
99                schema,
100                partitioning,
101                input_stage,
102            }) => {
103                let schema: Schema = schema
104                    .as_ref()
105                    .map(|s| s.try_into())
106                    .ok_or(proto_error("NetworkShuffleExec is missing schema"))??;
107
108                let decode_ctx = PhysicalPlanDecodeContext::new(ctx, &DistributedCodec {});
109                let partitioning = parse_protobuf_partitioning(
110                    partitioning.as_ref(),
111                    &decode_ctx,
112                    &schema,
113                    &DefaultPhysicalProtoConverter {},
114                )?
115                .ok_or(proto_error("NetworkShuffleExec is missing partitioning"))?;
116
117                Ok(Arc::new(new_network_hash_shuffle_exec(
118                    partitioning,
119                    Arc::new(schema),
120                    parse_stage_proto(input_stage, inputs)?,
121                )))
122            }
123            DistributedExecNode::NetworkCoalesceTasks(NetworkCoalesceExecProto {
124                schema,
125                partitioning,
126                input_stage,
127            }) => {
128                let schema: Schema = schema
129                    .as_ref()
130                    .map(|s| s.try_into())
131                    .ok_or(proto_error("NetworkCoalesceExec is missing schema"))??;
132
133                let decode_ctx = PhysicalPlanDecodeContext::new(ctx, &DistributedCodec {});
134                let partitioning = parse_protobuf_partitioning(
135                    partitioning.as_ref(),
136                    &decode_ctx,
137                    &schema,
138                    &DefaultPhysicalProtoConverter {},
139                )?
140                .ok_or(proto_error("NetworkCoalesceExec is missing partitioning"))?;
141
142                Ok(Arc::new(new_network_coalesce_tasks_exec(
143                    partitioning,
144                    Arc::new(schema),
145                    parse_stage_proto(input_stage, inputs)?,
146                )))
147            }
148            DistributedExecNode::NetworkBroadcast(NetworkBroadcastExecProto {
149                schema,
150                partitioning,
151                input_stage,
152            }) => {
153                let schema: Schema = schema
154                    .as_ref()
155                    .map(|s| s.try_into())
156                    .ok_or(proto_error("NetworkBroadcastExec is missing schema"))??;
157
158                let decode_ctx = PhysicalPlanDecodeContext::new(ctx, &DistributedCodec {});
159                let partitioning = parse_protobuf_partitioning(
160                    partitioning.as_ref(),
161                    &decode_ctx,
162                    &schema,
163                    &DefaultPhysicalProtoConverter {},
164                )?
165                .ok_or(proto_error("NetworkBroadcastExec is missing partitioning"))?;
166
167                Ok(Arc::new(new_network_broadcast_exec(
168                    partitioning,
169                    Arc::new(schema),
170                    parse_stage_proto(input_stage, inputs)?,
171                )))
172            }
173            DistributedExecNode::Broadcast(BroadcastExecProto {
174                consumer_task_count,
175            }) => {
176                if inputs.len() != 1 {
177                    return Err(proto_error(format!(
178                        "BroadcastExec expects exactly one child, got {}",
179                        inputs.len()
180                    )));
181                }
182
183                let child = inputs.first().unwrap();
184                Ok(Arc::new(BroadcastExec::new(
185                    child.clone(),
186                    consumer_task_count as usize,
187                )))
188            }
189            DistributedExecNode::ChildrenIsolatorUnion(ChildrenIsolatorUnionExecProto {
190                partition_count,
191                task_idx_map,
192                child_weights,
193            }) => {
194                // Building a UnionExec just to get the properties out of it is not the most
195                // efficient thing to do. However, it's the easiest way of getting the properties
196                // for the ChildrenIsolatorUnionExec without copy-pasting in this project
197                // all the machinery that builds them for UnionExec.
198                let mut properties = UnionExec::try_new(inputs.to_vec())?
199                    .properties()
200                    .as_ref()
201                    .clone();
202                properties.partitioning =
203                    Partitioning::UnknownPartitioning(partition_count as usize);
204
205                Ok(Arc::new(ChildrenIsolatorUnionExec {
206                    properties: Arc::new(properties),
207                    metrics: Default::default(),
208                    children: inputs.to_vec(),
209                    child_weights: child_weights
210                        .iter()
211                        .map(|cw| ChildWeight {
212                            weight: cw.weight,
213                            max: cw.max.map(|m| m as usize),
214                        })
215                        .collect(),
216                    task_idx_map: task_idx_map
217                        .iter()
218                        .map(|entry| {
219                            entry
220                                .child_ctx
221                                .iter()
222                                .map(|child_ctx| {
223                                    (
224                                        child_ctx.child_idx as usize,
225                                        DistributedTaskContext {
226                                            task_index: child_ctx.task_idx as usize,
227                                            task_count: child_ctx.task_count as usize,
228                                        },
229                                    )
230                                })
231                                .collect_vec()
232                        })
233                        .collect(),
234                }))
235            }
236        }
237    }
238
239    fn try_encode(&self, node: Arc<dyn ExecutionPlan>, buf: &mut Vec<u8>) -> Result<()> {
240        fn encode_stage_proto(stage: &Stage) -> Result<StageProto, DataFusionError> {
241            Ok(match stage {
242                Stage::Local(local) => StageProto {
243                    query_id: serialize_uuid(&local.query_id).into(),
244                    num: local.num as u64,
245                    tasks: vec![ExecutionTaskProto::default(); local.tasks],
246                },
247                Stage::Remote(remote) => {
248                    let mut tasks = Vec::with_capacity(remote.workers.len());
249                    for worker in &remote.workers {
250                        tasks.push(ExecutionTaskProto {
251                            url_str: Some(worker.to_string()),
252                        })
253                    }
254                    StageProto {
255                        query_id: serialize_uuid(&remote.query_id).into(),
256                        num: remote.num as u64,
257                        tasks,
258                    }
259                }
260            })
261        }
262
263        if let Some(node) = node.downcast_ref::<NetworkShuffleExec>() {
264            let inner = NetworkShuffleExecProto {
265                schema: Some(node.schema().try_into()?),
266                partitioning: Some(serialize_partitioning(
267                    node.properties().output_partitioning(),
268                    &DistributedCodec {},
269                    &DefaultPhysicalProtoConverter {},
270                )?),
271                input_stage: Some(encode_stage_proto(node.input_stage())?),
272            };
273
274            let wrapper = DistributedExecProto {
275                node: Some(DistributedExecNode::NetworkHashShuffle(inner)),
276            };
277
278            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
279        } else if let Some(node) = node.downcast_ref::<NetworkCoalesceExec>() {
280            let inner = NetworkCoalesceExecProto {
281                schema: Some(node.schema().try_into()?),
282                partitioning: Some(serialize_partitioning(
283                    node.properties().output_partitioning(),
284                    &DistributedCodec {},
285                    &DefaultPhysicalProtoConverter {},
286                )?),
287                input_stage: Some(encode_stage_proto(node.input_stage())?),
288            };
289
290            let wrapper = DistributedExecProto {
291                node: Some(DistributedExecNode::NetworkCoalesceTasks(inner)),
292            };
293
294            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
295        } else if let Some(node) = node.downcast_ref::<NetworkBroadcastExec>() {
296            let inner = NetworkBroadcastExecProto {
297                schema: Some(node.schema().try_into()?),
298                partitioning: Some(serialize_partitioning(
299                    node.properties().output_partitioning(),
300                    &DistributedCodec {},
301                    &DefaultPhysicalProtoConverter {},
302                )?),
303                input_stage: Some(encode_stage_proto(node.input_stage())?),
304            };
305
306            let wrapper = DistributedExecProto {
307                node: Some(DistributedExecNode::NetworkBroadcast(inner)),
308            };
309
310            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
311        } else if let Some(node) = node.downcast_ref::<BroadcastExec>() {
312            let inner = BroadcastExecProto {
313                consumer_task_count: node.consumer_task_count() as u64,
314            };
315
316            let wrapper = DistributedExecProto {
317                node: Some(DistributedExecNode::Broadcast(inner)),
318            };
319
320            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
321        } else if let Some(node) = node.downcast_ref::<ChildrenIsolatorUnionExec>() {
322            let inner = ChildrenIsolatorUnionExecProto {
323                partition_count: node.properties().output_partitioning().partition_count() as u64,
324                task_idx_map: node
325                    .task_idx_map
326                    .iter()
327                    .map(|v| TaskIdxMapEntryProto {
328                        child_ctx: v
329                            .iter()
330                            .map(|(child_idx, task_ctx)| ChildIdxWithTaskContextProto {
331                                child_idx: *child_idx as u64,
332                                task_idx: task_ctx.task_index as u64,
333                                task_count: task_ctx.task_count as u64,
334                            })
335                            .collect_vec(),
336                    })
337                    .collect_vec(),
338                child_weights: node
339                    .child_weights
340                    .iter()
341                    .map(|cw| ChildWeightProto {
342                        weight: cw.weight,
343                        max: cw.max.map(|m| m as u64),
344                    })
345                    .collect_vec(),
346            };
347
348            let wrapper = DistributedExecProto {
349                node: Some(DistributedExecNode::ChildrenIsolatorUnion(inner)),
350            };
351
352            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
353        } else {
354            Err(proto_error(format!("Unexpected plan {}", node.name())))
355        }
356    }
357}
358
359#[derive(Clone, PartialEq, ::prost::Message)]
360pub struct StageProto {
361    /// Our query id
362    #[prost(bytes, tag = "1")]
363    pub query_id: Bytes,
364    /// Our stage number
365    #[prost(uint64, tag = "2")]
366    pub num: u64,
367    /// Our tasks which tell us how finely grained to execute the partitions in
368    /// the plan
369    #[prost(message, repeated, tag = "3")]
370    pub tasks: Vec<ExecutionTaskProto>,
371}
372
373#[derive(Clone, PartialEq, ::prost::Message)]
374pub struct ExecutionTaskProto {
375    /// The url of the worker that will execute this task.  A None value is interpreted as
376    /// unassigned.
377    #[prost(string, optional, tag = "1")]
378    pub url_str: Option<String>,
379}
380
381#[derive(Clone, PartialEq, ::prost::Message)]
382pub struct DistributedExecProto {
383    #[prost(oneof = "DistributedExecNode", tags = "1, 2, 3, 4, 5, 6")]
384    pub node: Option<DistributedExecNode>,
385}
386
387#[derive(Clone, PartialEq, prost::Oneof)]
388pub enum DistributedExecNode {
389    #[prost(message, tag = "1")]
390    NetworkHashShuffle(NetworkShuffleExecProto),
391    #[prost(message, tag = "2")]
392    NetworkCoalesceTasks(NetworkCoalesceExecProto),
393    // reserved 3
394    #[prost(message, tag = "4")]
395    ChildrenIsolatorUnion(ChildrenIsolatorUnionExecProto),
396    #[prost(message, tag = "5")]
397    NetworkBroadcast(NetworkBroadcastExecProto),
398    #[prost(message, tag = "6")]
399    Broadcast(BroadcastExecProto),
400}
401
402/// Protobuf representation of the [NetworkShuffleExec] physical node. It serves as
403/// an intermediate format for serializing/deserializing [NetworkShuffleExec] nodes
404/// to send them over the wire.
405#[derive(Clone, PartialEq, ::prost::Message)]
406pub struct NetworkShuffleExecProto {
407    #[prost(message, optional, tag = "1")]
408    schema: Option<protobuf::Schema>,
409    #[prost(message, optional, tag = "2")]
410    partitioning: Option<protobuf::Partitioning>,
411    #[prost(message, optional, tag = "3")]
412    input_stage: Option<StageProto>,
413}
414
415#[derive(Clone, PartialEq, ::prost::Message)]
416pub struct ChildrenIsolatorUnionExecProto {
417    #[prost(uint64, tag = "1")]
418    partition_count: u64,
419    #[prost(message, repeated, tag = "2")]
420    task_idx_map: Vec<TaskIdxMapEntryProto>,
421    #[prost(message, repeated, tag = "3")]
422    child_weights: Vec<ChildWeightProto>,
423}
424
425#[derive(Clone, PartialEq, ::prost::Message)]
426pub struct ChildWeightProto {
427    #[prost(double, tag = "1")]
428    weight: f64,
429    #[prost(uint64, optional, tag = "2")]
430    max: Option<u64>,
431}
432
433#[derive(Clone, PartialEq, ::prost::Message)]
434pub struct TaskIdxMapEntryProto {
435    #[prost(message, repeated, tag = "1")]
436    child_ctx: Vec<ChildIdxWithTaskContextProto>,
437}
438
439#[derive(Clone, PartialEq, ::prost::Message)]
440pub struct ChildIdxWithTaskContextProto {
441    #[prost(uint64, tag = "1")]
442    child_idx: u64,
443    #[prost(uint64, tag = "2")]
444    task_idx: u64,
445    #[prost(uint64, tag = "3")]
446    task_count: u64,
447}
448
449fn new_network_hash_shuffle_exec(
450    partitioning: Partitioning,
451    schema: SchemaRef,
452    input_stage: Stage,
453) -> NetworkShuffleExec {
454    NetworkShuffleExec {
455        properties: Arc::new(PlanProperties::new(
456            EquivalenceProperties::new(schema),
457            partitioning,
458            EmissionType::Incremental,
459            Boundedness::Bounded,
460        )),
461        worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
462        input_stage,
463    }
464}
465
466/// Protobuf representation of the [NetworkShuffleExec] physical node. It serves as
467/// an intermediate format for serializing/deserializing [NetworkShuffleExec] nodes
468/// to send them over the wire.
469#[derive(Clone, PartialEq, ::prost::Message)]
470pub struct NetworkCoalesceExecProto {
471    #[prost(message, optional, tag = "1")]
472    schema: Option<protobuf::Schema>,
473    #[prost(message, optional, tag = "2")]
474    partitioning: Option<protobuf::Partitioning>,
475    #[prost(message, optional, tag = "3")]
476    input_stage: Option<StageProto>,
477}
478
479fn new_network_coalesce_tasks_exec(
480    partitioning: Partitioning,
481    schema: SchemaRef,
482    input_stage: Stage,
483) -> NetworkCoalesceExec {
484    NetworkCoalesceExec {
485        properties: Arc::new(PlanProperties::new(
486            EquivalenceProperties::new(schema),
487            partitioning,
488            EmissionType::Incremental,
489            Boundedness::Bounded,
490        )),
491        worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
492        input_stage,
493    }
494}
495
496#[derive(Clone, PartialEq, ::prost::Message)]
497pub struct NetworkBroadcastExecProto {
498    #[prost(message, optional, tag = "1")]
499    schema: Option<protobuf::Schema>,
500    #[prost(message, optional, tag = "2")]
501    partitioning: Option<protobuf::Partitioning>,
502    #[prost(message, optional, tag = "3")]
503    input_stage: Option<StageProto>,
504}
505
506#[derive(Clone, PartialEq, ::prost::Message)]
507pub struct BroadcastExecProto {
508    #[prost(uint64, tag = "1")]
509    pub consumer_task_count: u64,
510}
511
512fn new_network_broadcast_exec(
513    partitioning: Partitioning,
514    schema: SchemaRef,
515    input_stage: Stage,
516) -> NetworkBroadcastExec {
517    NetworkBroadcastExec {
518        properties: Arc::new(PlanProperties::new(
519            EquivalenceProperties::new(schema),
520            partitioning,
521            EmissionType::Incremental,
522            Boundedness::Bounded,
523        )),
524        worker_connections: WorkerConnectionPool::new(input_stage.task_count()),
525        input_stage,
526    }
527}
528
529#[cfg(test)]
530mod tests {
531    use super::*;
532    use datafusion::arrow::datatypes::{DataType, Field};
533    use datafusion::physical_expr::LexOrdering;
534    use datafusion::physical_plan::empty::EmptyExec;
535    use datafusion::prelude::SessionContext;
536    use datafusion::{
537        physical_expr::{Partitioning, PhysicalSortExpr, expressions::Column, expressions::col},
538        physical_plan::{ExecutionPlan, displayable, sorts::sort::SortExec, union::UnionExec},
539    };
540
541    fn empty_exec() -> Arc<dyn ExecutionPlan> {
542        Arc::new(EmptyExec::new(SchemaRef::new(Schema::empty())))
543    }
544
545    fn dummy_stage() -> Stage {
546        Stage::Remote(RemoteStage {
547            query_id: Default::default(),
548            num: 0,
549            workers: vec![],
550        })
551    }
552
553    fn dummy_stage_with_plan() -> Stage {
554        Stage::Local(LocalStage {
555            query_id: Default::default(),
556            num: 0,
557            plan: empty_exec(),
558            tasks: 1,
559        })
560    }
561
562    fn schema_i32(name: &str) -> Arc<Schema> {
563        Arc::new(Schema::new(vec![Field::new(name, DataType::Int32, false)]))
564    }
565
566    fn repr(plan: &Arc<dyn ExecutionPlan>) -> String {
567        displayable(plan.as_ref()).indent(true).to_string()
568    }
569
570    fn create_context() -> Arc<TaskContext> {
571        SessionContext::new().task_ctx()
572    }
573
574    #[test]
575    fn test_roundtrip_single_flight() -> datafusion::common::Result<()> {
576        let codec = DistributedCodec;
577        let ctx = create_context();
578
579        let schema = schema_i32("a");
580        let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
581        let plan: Arc<dyn ExecutionPlan> =
582            Arc::new(new_network_hash_shuffle_exec(part, schema, dummy_stage()));
583
584        let mut buf = Vec::new();
585        codec.try_encode(plan.clone(), &mut buf)?;
586
587        let decoded = codec.try_decode(&buf, &[], &ctx)?;
588        assert_eq!(repr(&plan), repr(&decoded));
589
590        Ok(())
591    }
592
593    #[test]
594    fn test_roundtrip_union() -> datafusion::common::Result<()> {
595        let codec = DistributedCodec;
596        let ctx = create_context();
597
598        let schema = schema_i32("c");
599        let left = Arc::new(new_network_hash_shuffle_exec(
600            Partitioning::RoundRobinBatch(2),
601            schema.clone(),
602            dummy_stage(),
603        ));
604        let right = Arc::new(new_network_hash_shuffle_exec(
605            Partitioning::RoundRobinBatch(2),
606            schema.clone(),
607            dummy_stage(),
608        ));
609
610        let union = UnionExec::try_new(vec![left.clone(), right.clone()])?;
611        let plan: Arc<dyn ExecutionPlan> =
612            Arc::new(NetworkCoalesceExec::try_new(union.clone(), 1, 1)?);
613
614        let mut buf = Vec::new();
615        codec.try_encode(plan.clone(), &mut buf)?;
616
617        let decoded = codec.try_decode(&buf, &[union], &ctx)?;
618        assert_eq!(repr(&plan), repr(&decoded));
619
620        Ok(())
621    }
622
623    #[test]
624    fn test_roundtrip_sort_flight() -> datafusion::common::Result<()> {
625        let codec = DistributedCodec;
626        let ctx = create_context();
627
628        let schema = schema_i32("d");
629        let flight = Arc::new(new_network_hash_shuffle_exec(
630            Partitioning::UnknownPartitioning(1),
631            schema.clone(),
632            dummy_stage(),
633        ));
634
635        let sort_expr = PhysicalSortExpr {
636            expr: col("d", &schema)?,
637            options: Default::default(),
638        };
639        let sort = Arc::new(SortExec::new(
640            LexOrdering::new(vec![sort_expr]).unwrap(),
641            flight.clone(),
642        ));
643
644        let plan: Arc<dyn ExecutionPlan> =
645            Arc::new(NetworkCoalesceExec::try_new(sort.clone(), 1, 1)?);
646
647        let mut buf = Vec::new();
648        codec.try_encode(plan.clone(), &mut buf)?;
649
650        let decoded = codec.try_decode(&buf, &[sort], &ctx)?;
651        assert_eq!(repr(&plan), repr(&decoded));
652
653        Ok(())
654    }
655
656    #[test]
657    fn test_roundtrip_single_flight_coalesce() -> datafusion::common::Result<()> {
658        let codec = DistributedCodec;
659        let ctx = create_context();
660
661        let schema = schema_i32("e");
662        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_coalesce_tasks_exec(
663            Partitioning::RoundRobinBatch(3),
664            schema,
665            dummy_stage(),
666        ));
667
668        let mut buf = Vec::new();
669        codec.try_encode(plan.clone(), &mut buf)?;
670
671        let decoded = codec.try_decode(&buf, &[], &ctx)?;
672        assert_eq!(repr(&plan), repr(&decoded));
673
674        Ok(())
675    }
676
677    #[test]
678    fn test_roundtrip_single_flight_with_plan() -> datafusion::common::Result<()> {
679        let codec = DistributedCodec;
680        let ctx = create_context();
681
682        let schema = schema_i32("a");
683        let part = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
684        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_hash_shuffle_exec(
685            part,
686            schema,
687            dummy_stage_with_plan(),
688        ));
689
690        let mut buf = Vec::new();
691        codec.try_encode(plan.clone(), &mut buf)?;
692
693        let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?;
694        assert_eq!(repr(&plan), repr(&decoded));
695
696        Ok(())
697    }
698
699    #[test]
700    fn test_roundtrip_single_flight_coalesce_with_plan() -> datafusion::common::Result<()> {
701        let codec = DistributedCodec;
702        let ctx = create_context();
703
704        let schema = schema_i32("e");
705        let plan: Arc<dyn ExecutionPlan> = Arc::new(new_network_coalesce_tasks_exec(
706            Partitioning::RoundRobinBatch(3),
707            schema,
708            dummy_stage_with_plan(),
709        ));
710
711        let mut buf = Vec::new();
712        codec.try_encode(plan.clone(), &mut buf)?;
713
714        let decoded = codec.try_decode(&buf, &[empty_exec()], &ctx)?;
715        assert_eq!(repr(&plan), repr(&decoded));
716
717        Ok(())
718    }
719
720    #[test]
721    fn test_roundtrip_flight_coalesce() -> datafusion::common::Result<()> {
722        let codec = DistributedCodec;
723        let ctx = create_context();
724
725        let schema = schema_i32("f");
726        let flight = Arc::new(new_network_coalesce_tasks_exec(
727            Partitioning::UnknownPartitioning(1),
728            schema,
729            dummy_stage(),
730        ));
731
732        let plan: Arc<dyn ExecutionPlan> =
733            Arc::new(NetworkCoalesceExec::try_new(flight.clone(), 1, 1)?);
734
735        let mut buf = Vec::new();
736        codec.try_encode(plan.clone(), &mut buf)?;
737
738        let decoded = codec.try_decode(&buf, &[flight], &ctx)?;
739        assert_eq!(repr(&plan), repr(&decoded));
740
741        Ok(())
742    }
743
744    #[test]
745    fn test_roundtrip_union_coalesce() -> datafusion::common::Result<()> {
746        let codec = DistributedCodec;
747        let ctx = create_context();
748
749        let schema = schema_i32("g");
750        let left = Arc::new(new_network_coalesce_tasks_exec(
751            Partitioning::RoundRobinBatch(2),
752            schema.clone(),
753            dummy_stage(),
754        ));
755        let right = Arc::new(new_network_coalesce_tasks_exec(
756            Partitioning::RoundRobinBatch(2),
757            schema.clone(),
758            dummy_stage(),
759        ));
760
761        let union = UnionExec::try_new(vec![left.clone(), right.clone()])?;
762        let plan: Arc<dyn ExecutionPlan> =
763            Arc::new(NetworkCoalesceExec::try_new(union.clone(), 1, 1)?);
764
765        let mut buf = Vec::new();
766        codec.try_encode(plan.clone(), &mut buf)?;
767
768        let decoded = codec.try_decode(&buf, &[union], &ctx)?;
769        assert_eq!(repr(&plan), repr(&decoded));
770
771        Ok(())
772    }
773
774    #[test]
775    fn test_roundtrip_children_isolator_union() -> datafusion::common::Result<()> {
776        let codec = DistributedCodec;
777        let ctx = create_context();
778
779        let schema = schema_i32("h");
780        let left = Arc::new(new_network_hash_shuffle_exec(
781            Partitioning::RoundRobinBatch(2),
782            schema.clone(),
783            dummy_stage(),
784        )) as Arc<dyn ExecutionPlan>;
785        let right = Arc::new(new_network_hash_shuffle_exec(
786            Partitioning::RoundRobinBatch(2),
787            schema.clone(),
788            dummy_stage(),
789        )) as Arc<dyn ExecutionPlan>;
790
791        let plan: Arc<dyn ExecutionPlan> =
792            Arc::new(ChildrenIsolatorUnionExec::from_children_and_weights(
793                vec![left.clone(), right.clone()],
794                vec![ChildWeight::desired(3.0), ChildWeight::maximum(1)],
795                4,
796            )?);
797
798        let mut buf = Vec::new();
799        codec.try_encode(plan.clone(), &mut buf)?;
800
801        let decoded = codec.try_decode(&buf, &[left, right], &ctx)?;
802        assert_eq!(repr(&plan), repr(&decoded));
803
804        Ok(())
805    }
806}