Skip to main content

datafusion_distributed/protobuf/
distributed_codec.rs

1use super::get_distributed_user_codecs;
2use crate::execution_plans::{
3    BroadcastExec, ChildrenIsolatorUnionExec, NetworkBroadcastExec, NetworkCoalesceExec,
4};
5use crate::stage::{ExecutionTask, Stage};
6use crate::worker::WorkerConnectionPool;
7use crate::{DistributedTaskContext, NetworkBoundary};
8use crate::{NetworkShuffleExec, PartitionIsolatorExec};
9use bytes::Bytes;
10use datafusion::arrow::datatypes::Schema;
11use datafusion::arrow::datatypes::SchemaRef;
12use datafusion::common::{Result, internal_datafusion_err};
13use datafusion::error::DataFusionError;
14use datafusion::execution::TaskContext;
15use datafusion::physical_expr::EquivalenceProperties;
16use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
17use datafusion::physical_plan::union::UnionExec;
18use datafusion::physical_plan::{ExecutionPlan, Partitioning, PlanProperties};
19use datafusion::prelude::SessionConfig;
20use datafusion_proto::physical_plan::from_proto::parse_protobuf_partitioning;
21use datafusion_proto::physical_plan::to_proto::serialize_partitioning;
22use datafusion_proto::physical_plan::{
23    ComposedPhysicalExtensionCodec, DefaultPhysicalProtoConverter, PhysicalExtensionCodec,
24};
25use datafusion_proto::protobuf;
26use datafusion_proto::protobuf::proto_error;
27use itertools::Itertools;
28use prost::Message;
29use std::sync::Arc;
30use url::Url;
31
/// DataFusion [PhysicalExtensionCodec] implementation that allows serializing and
/// deserializing the custom ExecutionPlans in this project.
///
/// The plans it handles are [NetworkShuffleExec], [NetworkCoalesceExec],
/// [NetworkBroadcastExec], [PartitionIsolatorExec], [BroadcastExec] and
/// [ChildrenIsolatorUnionExec]; encoding any other plan returns an error. See
/// [DistributedCodec::new_combined_with_user] to compose it with the codecs returned by
/// `get_distributed_user_codecs` for a given [SessionConfig].
#[derive(Debug)]
pub struct DistributedCodec;
36
37impl DistributedCodec {
38    pub fn new_combined_with_user(cfg: &SessionConfig) -> impl PhysicalExtensionCodec + use<> {
39        let mut codecs: Vec<Arc<dyn PhysicalExtensionCodec>> = vec![Arc::new(DistributedCodec {})];
40        codecs.extend(get_distributed_user_codecs(cfg));
41        ComposedPhysicalExtensionCodec::new(codecs)
42    }
43}
44
45impl PhysicalExtensionCodec for DistributedCodec {
46    fn try_decode(
47        &self,
48        buf: &[u8],
49        inputs: &[Arc<dyn ExecutionPlan>],
50        ctx: &TaskContext,
51    ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
52        let DistributedExecProto {
53            node: Some(distributed_exec_node),
54        } = DistributedExecProto::decode(buf).map_err(|err| proto_error(format!("{err}")))?
55        else {
56            return Err(proto_error(
57                "Expected DistributedExecNode in DistributedExecProto",
58            ));
59        };
60
61        fn parse_stage_proto(
62            proto: Option<StageProto>,
63            inputs: &[Arc<dyn ExecutionPlan>],
64        ) -> Result<Stage, DataFusionError> {
65            let Some(proto) = proto else {
66                return Err(proto_error("Empty StageProto"));
67            };
68
69            Ok(Stage {
70                query_id: uuid::Uuid::from_slice(proto.query_id.as_ref())
71                    .map_err(|_| proto_error("Invalid query_id in StageProto"))?,
72                num: proto.num as usize,
73                plan: inputs.first().cloned(),
74                tasks: decode_tasks(proto.tasks)?,
75            })
76        }
77
78        match distributed_exec_node {
79            DistributedExecNode::NetworkHashShuffle(NetworkShuffleExecProto {
80                schema,
81                partitioning,
82                input_stage,
83            }) => {
84                let schema: Schema = schema
85                    .as_ref()
86                    .map(|s| s.try_into())
87                    .ok_or(proto_error("NetworkShuffleExec is missing schema"))??;
88
89                let partitioning = parse_protobuf_partitioning(
90                    partitioning.as_ref(),
91                    ctx,
92                    &schema,
93                    &DistributedCodec {},
94                    &DefaultPhysicalProtoConverter,
95                )?
96                .ok_or(proto_error("NetworkShuffleExec is missing partitioning"))?;
97
98                Ok(Arc::new(new_network_hash_shuffle_exec(
99                    partitioning,
100                    Arc::new(schema),
101                    parse_stage_proto(input_stage, inputs)?,
102                )))
103            }
104            DistributedExecNode::NetworkCoalesceTasks(NetworkCoalesceExecProto {
105                schema,
106                partitioning,
107                input_stage,
108            }) => {
109                let schema: Schema = schema
110                    .as_ref()
111                    .map(|s| s.try_into())
112                    .ok_or(proto_error("NetworkCoalesceExec is missing schema"))??;
113
114                let partitioning = parse_protobuf_partitioning(
115                    partitioning.as_ref(),
116                    ctx,
117                    &schema,
118                    &DistributedCodec {},
119                    &DefaultPhysicalProtoConverter,
120                )?
121                .ok_or(proto_error("NetworkCoalesceExec is missing partitioning"))?;
122
123                Ok(Arc::new(new_network_coalesce_tasks_exec(
124                    partitioning,
125                    Arc::new(schema),
126                    parse_stage_proto(input_stage, inputs)?,
127                )))
128            }
129            DistributedExecNode::PartitionIsolator(PartitionIsolatorExecProto { n_tasks }) => {
130                if inputs.len() != 1 {
131                    return Err(proto_error(format!(
132                        "PartitionIsolatorExec expects exactly one child, got {}",
133                        inputs.len()
134                    )));
135                }
136
137                let child = inputs.first().unwrap();
138
139                Ok(Arc::new(PartitionIsolatorExec::new(
140                    child.clone(),
141                    n_tasks as usize,
142                )))
143            }
144            DistributedExecNode::NetworkBroadcast(NetworkBroadcastExecProto {
145                schema,
146                partitioning,
147                input_stage,
148            }) => {
149                let schema: Schema = schema
150                    .as_ref()
151                    .map(|s| s.try_into())
152                    .ok_or(proto_error("NetworkBroadcastExec is missing schema"))??;
153
154                let partitioning = parse_protobuf_partitioning(
155                    partitioning.as_ref(),
156                    ctx,
157                    &schema,
158                    &DistributedCodec {},
159                    &DefaultPhysicalProtoConverter,
160                )?
161                .ok_or(proto_error("NetworkBroadcastExec is missing partitioning"))?;
162
163                Ok(Arc::new(new_network_broadcast_exec(
164                    partitioning,
165                    Arc::new(schema),
166                    parse_stage_proto(input_stage, inputs)?,
167                )))
168            }
169            DistributedExecNode::Broadcast(BroadcastExecProto {
170                consumer_task_count,
171            }) => {
172                if inputs.len() != 1 {
173                    return Err(proto_error(format!(
174                        "BroadcastExec expects exactly one child, got {}",
175                        inputs.len()
176                    )));
177                }
178
179                let child = inputs.first().unwrap();
180                Ok(Arc::new(BroadcastExec::new(
181                    child.clone(),
182                    consumer_task_count as usize,
183                )))
184            }
185            DistributedExecNode::ChildrenIsolatorUnion(ChildrenIsolatorUnionExecProto {
186                partition_count,
187                task_idx_map,
188            }) => {
189                // Building a UnionExec just to get the properties out of it is not the most
190                // efficient thing to do. However, it's the easiest way of getting the properties
191                // for the ChildrenIsolatorUnionExec without copy-pasting in this project
192                // all the machinery that builds them for UnionExec.
193                let mut properties = UnionExec::try_new(inputs.to_vec())?
194                    .properties()
195                    .as_ref()
196                    .clone();
197                properties.partitioning =
198                    Partitioning::UnknownPartitioning(partition_count as usize);
199
200                Ok(Arc::new(ChildrenIsolatorUnionExec {
201                    properties: Arc::new(properties),
202                    metrics: Default::default(),
203                    children: inputs.to_vec(),
204                    task_idx_map: task_idx_map
205                        .iter()
206                        .map(|entry| {
207                            entry
208                                .child_ctx
209                                .iter()
210                                .map(|child_ctx| {
211                                    (
212                                        child_ctx.child_idx as usize,
213                                        DistributedTaskContext {
214                                            task_index: child_ctx.task_idx as usize,
215                                            task_count: child_ctx.task_count as usize,
216                                        },
217                                    )
218                                })
219                                .collect_vec()
220                        })
221                        .collect(),
222                }))
223            }
224        }
225    }
226
227    fn try_encode(
228        &self,
229        node: Arc<dyn ExecutionPlan>,
230        buf: &mut Vec<u8>,
231    ) -> datafusion::common::Result<()> {
232        fn encode_stage_proto(stage: &Stage) -> Result<StageProto, DataFusionError> {
233            Ok(StageProto {
234                query_id: Bytes::from(stage.query_id.as_bytes().to_vec()),
235                num: stage.num as u64,
236                tasks: encode_tasks(&stage.tasks),
237            })
238        }
239
240        if let Some(node) = node.as_any().downcast_ref::<NetworkShuffleExec>() {
241            let inner = NetworkShuffleExecProto {
242                schema: Some(node.schema().try_into()?),
243                partitioning: Some(serialize_partitioning(
244                    node.properties().output_partitioning(),
245                    &DistributedCodec {},
246                    &DefaultPhysicalProtoConverter,
247                )?),
248                input_stage: Some(encode_stage_proto(node.input_stage())?),
249            };
250
251            let wrapper = DistributedExecProto {
252                node: Some(DistributedExecNode::NetworkHashShuffle(inner)),
253            };
254
255            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
256        } else if let Some(node) = node.as_any().downcast_ref::<NetworkCoalesceExec>() {
257            let inner = NetworkCoalesceExecProto {
258                schema: Some(node.schema().try_into()?),
259                partitioning: Some(serialize_partitioning(
260                    node.properties().output_partitioning(),
261                    &DistributedCodec {},
262                    &DefaultPhysicalProtoConverter,
263                )?),
264                input_stage: Some(encode_stage_proto(node.input_stage())?),
265            };
266
267            let wrapper = DistributedExecProto {
268                node: Some(DistributedExecNode::NetworkCoalesceTasks(inner)),
269            };
270
271            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
272        } else if let Some(node) = node.as_any().downcast_ref::<PartitionIsolatorExec>() {
273            let inner = PartitionIsolatorExecProto {
274                n_tasks: node.n_tasks as u64,
275            };
276
277            let wrapper = DistributedExecProto {
278                node: Some(DistributedExecNode::PartitionIsolator(inner)),
279            };
280
281            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
282        } else if let Some(node) = node.as_any().downcast_ref::<NetworkBroadcastExec>() {
283            let inner = NetworkBroadcastExecProto {
284                schema: Some(node.schema().try_into()?),
285                partitioning: Some(serialize_partitioning(
286                    node.properties().output_partitioning(),
287                    &DistributedCodec {},
288                    &DefaultPhysicalProtoConverter,
289                )?),
290                input_stage: Some(encode_stage_proto(node.input_stage())?),
291            };
292
293            let wrapper = DistributedExecProto {
294                node: Some(DistributedExecNode::NetworkBroadcast(inner)),
295            };
296
297            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
298        } else if let Some(node) = node.as_any().downcast_ref::<BroadcastExec>() {
299            let inner = BroadcastExecProto {
300                consumer_task_count: node.consumer_task_count() as u64,
301            };
302
303            let wrapper = DistributedExecProto {
304                node: Some(DistributedExecNode::Broadcast(inner)),
305            };
306
307            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
308        } else if let Some(node) = node.as_any().downcast_ref::<ChildrenIsolatorUnionExec>() {
309            let inner = ChildrenIsolatorUnionExecProto {
310                partition_count: node.properties().output_partitioning().partition_count() as u64,
311                task_idx_map: node
312                    .task_idx_map
313                    .iter()
314                    .map(|v| TaskIdxMapEntryProto {
315                        child_ctx: v
316                            .iter()
317                            .map(|(child_idx, task_ctx)| ChildIdxWithTaskContextProto {
318                                child_idx: *child_idx as u64,
319                                task_idx: task_ctx.task_index as u64,
320                                task_count: task_ctx.task_count as u64,
321                            })
322                            .collect_vec(),
323                    })
324                    .collect_vec(),
325            };
326
327            let wrapper = DistributedExecProto {
328                node: Some(DistributedExecNode::ChildrenIsolatorUnion(inner)),
329            };
330
331            wrapper.encode(buf).map_err(|e| proto_error(format!("{e}")))
332        } else {
333            Err(proto_error(format!("Unexpected plan {}", node.name())))
334        }
335    }
336}
337
/// Wire representation of the [Stage] a network boundary node reads from.
///
/// Only the stage metadata travels in this message; the stage's plan is not a field
/// here (when decoding, `DistributedCodec` takes it from the node's first decoded input).
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct StageProto {
    /// Our query id, stored as the raw bytes of the query's UUID.
    #[prost(bytes, tag = "1")]
    pub query_id: Bytes,
    /// Our stage number
    #[prost(uint64, tag = "2")]
    pub num: u64,
    /// Our tasks which tell us how finely grained to execute the partitions in
    /// the plan
    #[prost(message, repeated, tag = "3")]
    pub tasks: Vec<ExecutionTaskProto>,
}
351
/// Wire representation of an [ExecutionTask].
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ExecutionTaskProto {
    /// The url of the worker that will execute this task.  A None value is interpreted as
    /// unassigned. The URL is stored in its string form and parsed again on decode.
    #[prost(string, optional, tag = "1")]
    pub url_str: Option<String>,
}
359
/// Top-level envelope written and read by [DistributedCodec]. Exactly one
/// [DistributedExecNode] variant is expected; decoding a message whose `node` is unset
/// fails.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct DistributedExecProto {
    #[prost(oneof = "DistributedExecNode", tags = "1, 2, 3, 4, 5, 6")]
    pub node: Option<DistributedExecNode>,
}
365
/// The concrete node carried by a [DistributedExecProto]; one variant per execution plan
/// handled by [DistributedCodec].
///
/// The tag numbers are part of the wire format: keep existing tags stable and use new
/// numbers for new variants (also listing them in `DistributedExecProto::node`).
#[derive(Clone, PartialEq, prost::Oneof)]
pub enum DistributedExecNode {
    /// Encodes a [NetworkShuffleExec].
    #[prost(message, tag = "1")]
    NetworkHashShuffle(NetworkShuffleExecProto),
    /// Encodes a [NetworkCoalesceExec].
    #[prost(message, tag = "2")]
    NetworkCoalesceTasks(NetworkCoalesceExecProto),
    /// Encodes a [PartitionIsolatorExec].
    #[prost(message, tag = "3")]
    PartitionIsolator(PartitionIsolatorExecProto),
    /// Encodes a [ChildrenIsolatorUnionExec].
    #[prost(message, tag = "4")]
    ChildrenIsolatorUnion(ChildrenIsolatorUnionExecProto),
    /// Encodes a [NetworkBroadcastExec].
    #[prost(message, tag = "5")]
    NetworkBroadcast(NetworkBroadcastExecProto),
    /// Encodes a [BroadcastExec].
    #[prost(message, tag = "6")]
    Broadcast(BroadcastExecProto),
}
381
/// Wire representation of a [PartitionIsolatorExec]. The child plan is not part of this
/// message; on decode it must be supplied as the single decoded input.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct PartitionIsolatorExecProto {
    /// Value of `PartitionIsolatorExec::n_tasks`.
    #[prost(uint64, tag = "1")]
    pub n_tasks: u64,
}
387
/// Protobuf representation of the [NetworkShuffleExec] physical node. It serves as
/// an intermediate format for serializing/deserializing [NetworkShuffleExec] nodes
/// to send them over the wire.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NetworkShuffleExecProto {
    /// Output schema of the node.
    #[prost(message, optional, tag = "1")]
    schema: Option<protobuf::Schema>,
    /// Output partitioning of the node.
    #[prost(message, optional, tag = "2")]
    partitioning: Option<protobuf::Partitioning>,
    /// Metadata of the stage this node reads from.
    #[prost(message, optional, tag = "3")]
    input_stage: Option<StageProto>,
}
400
/// Wire representation of a [ChildrenIsolatorUnionExec]. Its children are not part of
/// this message: on decode they are the decoded inputs, and the plan properties are
/// rebuilt from them.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ChildrenIsolatorUnionExecProto {
    /// Number of output partitions; decoded as `Partitioning::UnknownPartitioning`.
    #[prost(uint64, tag = "1")]
    partition_count: u64,
    /// One entry per element of `ChildrenIsolatorUnionExec::task_idx_map`, in order
    /// (presumably one entry per task of the union — TODO confirm).
    #[prost(message, repeated, tag = "2")]
    task_idx_map: Vec<TaskIdxMapEntryProto>,
}
408
/// One entry of a `ChildrenIsolatorUnionExec` task index map: the ordered list of
/// (child index, task context) pairs stored for that entry.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TaskIdxMapEntryProto {
    #[prost(message, repeated, tag = "1")]
    child_ctx: Vec<ChildIdxWithTaskContextProto>,
}
414
/// A child index paired with the [DistributedTaskContext] stored for it.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct ChildIdxWithTaskContextProto {
    /// Index of the child plan.
    #[prost(uint64, tag = "1")]
    child_idx: u64,
    /// `DistributedTaskContext::task_index`.
    #[prost(uint64, tag = "2")]
    task_idx: u64,
    /// `DistributedTaskContext::task_count`.
    #[prost(uint64, tag = "3")]
    task_count: u64,
}
424
425fn new_network_hash_shuffle_exec(
426    partitioning: Partitioning,
427    schema: SchemaRef,
428    input_stage: Stage,
429) -> NetworkShuffleExec {
430    NetworkShuffleExec {
431        properties: Arc::new(PlanProperties::new(
432            EquivalenceProperties::new(schema),
433            partitioning,
434            EmissionType::Incremental,
435            Boundedness::Bounded,
436        )),
437        worker_connections: WorkerConnectionPool::new(input_stage.tasks.len()),
438        input_stage,
439        metrics_collection: Default::default(),
440    }
441}
442
443/// Protobuf representation of the [NetworkShuffleExec] physical node. It serves as
444/// an intermediate format for serializing/deserializing [NetworkShuffleExec] nodes
445/// to send them over the wire.
446#[derive(Clone, PartialEq, ::prost::Message)]
447pub struct NetworkCoalesceExecProto {
448    #[prost(message, optional, tag = "1")]
449    schema: Option<protobuf::Schema>,
450    #[prost(message, optional, tag = "2")]
451    partitioning: Option<protobuf::Partitioning>,
452    #[prost(message, optional, tag = "3")]
453    input_stage: Option<StageProto>,
454}
455
456fn new_network_coalesce_tasks_exec(
457    partitioning: Partitioning,
458    schema: SchemaRef,
459    input_stage: Stage,
460) -> NetworkCoalesceExec {
461    NetworkCoalesceExec {
462        properties: Arc::new(PlanProperties::new(
463            EquivalenceProperties::new(schema),
464            partitioning,
465            EmissionType::Incremental,
466            Boundedness::Bounded,
467        )),
468        worker_connections: WorkerConnectionPool::new(input_stage.tasks.len()),
469        input_stage,
470        metrics_collection: Default::default(),
471    }
472}
473
/// Protobuf representation of the [NetworkBroadcastExec] physical node, used to send it
/// over the wire.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct NetworkBroadcastExecProto {
    /// Output schema of the node.
    #[prost(message, optional, tag = "1")]
    schema: Option<protobuf::Schema>,
    /// Output partitioning of the node.
    #[prost(message, optional, tag = "2")]
    partitioning: Option<protobuf::Partitioning>,
    /// Metadata of the stage this node reads from.
    #[prost(message, optional, tag = "3")]
    input_stage: Option<StageProto>,
}
483
/// Wire representation of a [BroadcastExec]. The child plan is not part of this message;
/// on decode it must be supplied as the single decoded input.
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct BroadcastExecProto {
    /// Value of `BroadcastExec::consumer_task_count()`.
    #[prost(uint64, tag = "1")]
    pub consumer_task_count: u64,
}
489
490fn new_network_broadcast_exec(
491    partitioning: Partitioning,
492    schema: SchemaRef,
493    input_stage: Stage,
494) -> NetworkBroadcastExec {
495    NetworkBroadcastExec {
496        properties: Arc::new(PlanProperties::new(
497            EquivalenceProperties::new(schema),
498            partitioning,
499            EmissionType::Incremental,
500            Boundedness::Bounded,
501        )),
502        worker_connections: WorkerConnectionPool::new(input_stage.tasks.len()),
503        input_stage,
504        metrics_collection: Default::default(),
505    }
506}
507
508fn encode_tasks(tasks: &[ExecutionTask]) -> Vec<ExecutionTaskProto> {
509    tasks
510        .iter()
511        .map(|task| ExecutionTaskProto {
512            url_str: task.url.as_ref().map(|v| v.to_string()),
513        })
514        .collect()
515}
516
517fn decode_tasks(tasks: Vec<ExecutionTaskProto>) -> Result<Vec<ExecutionTask>, DataFusionError> {
518    tasks
519        .into_iter()
520        .map(|task| {
521            Ok(ExecutionTask {
522                url: task
523                    .url_str
524                    .map(|u| {
525                        Url::parse(&u).map_err(|_| internal_datafusion_err!("Invalid URL: {u}"))
526                    })
527                    .transpose()?,
528            })
529        })
530        .collect()
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use datafusion::arrow::datatypes::{DataType, Field};
    use datafusion::physical_expr::LexOrdering;
    use datafusion::physical_plan::empty::EmptyExec;
    use datafusion::{
        physical_expr::{Partitioning, PhysicalSortExpr, expressions::Column, expressions::col},
        physical_plan::{ExecutionPlan, displayable, sorts::sort::SortExec, union::UnionExec},
    };

    use datafusion::prelude::SessionContext;

    /// Leaf plan with an empty schema, used wherever a test needs an arbitrary child.
    fn empty_exec() -> Arc<dyn ExecutionPlan> {
        Arc::new(EmptyExec::new(SchemaRef::new(Schema::empty())))
    }

    /// Stage with the default query id, number 0, no tasks and the given plan.
    fn stage(plan: Option<Arc<dyn ExecutionPlan>>) -> Stage {
        Stage {
            query_id: Default::default(),
            num: 0,
            plan,
            tasks: vec![],
        }
    }

    /// Schema with a single non-nullable `Int32` column called `name`.
    fn schema_i32(name: &str) -> Arc<Schema> {
        let field = Field::new(name, DataType::Int32, false);
        Arc::new(Schema::new(vec![field]))
    }

    /// Indented textual rendering of a plan, used to compare plans structurally.
    fn render(plan: &Arc<dyn ExecutionPlan>) -> String {
        displayable(plan.as_ref()).indent(true).to_string()
    }

    /// Encodes `plan` with [DistributedCodec], decodes it back with `decoded_inputs` as
    /// its children and asserts that both plans render identically.
    fn assert_roundtrip(
        plan: Arc<dyn ExecutionPlan>,
        decoded_inputs: &[Arc<dyn ExecutionPlan>],
    ) -> datafusion::common::Result<()> {
        let codec = DistributedCodec;
        let ctx = SessionContext::new().task_ctx();

        let mut buf = Vec::new();
        codec.try_encode(Arc::clone(&plan), &mut buf)?;
        let decoded = codec.try_decode(&buf, decoded_inputs, &ctx)?;

        assert_eq!(render(&plan), render(&decoded));
        Ok(())
    }

    #[test]
    fn test_roundtrip_single_flight() -> datafusion::common::Result<()> {
        let partitioning = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
        let plan = Arc::new(new_network_hash_shuffle_exec(
            partitioning,
            schema_i32("a"),
            stage(None),
        ));

        assert_roundtrip(plan, &[])
    }

    #[test]
    fn test_roundtrip_isolator_flight() -> datafusion::common::Result<()> {
        let flight = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::UnknownPartitioning(1),
            schema_i32("b"),
            stage(None),
        ));
        let plan = Arc::new(PartitionIsolatorExec::new(flight.clone(), 1));

        assert_roundtrip(plan, &[flight])
    }

    #[test]
    fn test_roundtrip_isolator_union() -> datafusion::common::Result<()> {
        let schema = schema_i32("c");
        let left = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            stage(None),
        ));
        let right = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            stage(None),
        ));

        let union = UnionExec::try_new(vec![left.clone(), right.clone()])?;
        let plan = Arc::new(PartitionIsolatorExec::new(union.clone(), 1));

        assert_roundtrip(plan, &[union])
    }

    #[test]
    fn test_roundtrip_isolator_sort_flight() -> datafusion::common::Result<()> {
        let schema = schema_i32("d");
        let flight = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::UnknownPartitioning(1),
            schema.clone(),
            stage(None),
        ));

        let ordering = LexOrdering::new(vec![PhysicalSortExpr {
            expr: col("d", &schema)?,
            options: Default::default(),
        }])
        .unwrap();
        let sort = Arc::new(SortExec::new(ordering, flight.clone()));
        let plan = Arc::new(PartitionIsolatorExec::new(sort.clone(), 1));

        assert_roundtrip(plan, &[sort])
    }

    #[test]
    fn test_roundtrip_single_flight_coalesce() -> datafusion::common::Result<()> {
        let plan = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::RoundRobinBatch(3),
            schema_i32("e"),
            stage(None),
        ));

        assert_roundtrip(plan, &[])
    }

    #[test]
    fn test_roundtrip_single_flight_with_plan() -> datafusion::common::Result<()> {
        let partitioning = Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 4);
        let plan = Arc::new(new_network_hash_shuffle_exec(
            partitioning,
            schema_i32("a"),
            stage(Some(empty_exec())),
        ));

        assert_roundtrip(plan, &[empty_exec()])
    }

    #[test]
    fn test_roundtrip_single_flight_coalesce_with_plan() -> datafusion::common::Result<()> {
        let plan = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::RoundRobinBatch(3),
            schema_i32("e"),
            stage(Some(empty_exec())),
        ));

        assert_roundtrip(plan, &[empty_exec()])
    }

    #[test]
    fn test_roundtrip_isolator_flight_coalesce() -> datafusion::common::Result<()> {
        let flight = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::UnknownPartitioning(1),
            schema_i32("f"),
            stage(None),
        ));
        let plan = Arc::new(PartitionIsolatorExec::new(flight.clone(), 1));

        assert_roundtrip(plan, &[flight])
    }

    #[test]
    fn test_roundtrip_isolator_union_coalesce() -> datafusion::common::Result<()> {
        let schema = schema_i32("g");
        let left = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            stage(None),
        ));
        let right = Arc::new(new_network_coalesce_tasks_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            stage(None),
        ));

        let union = UnionExec::try_new(vec![left.clone(), right.clone()])?;
        let plan = Arc::new(PartitionIsolatorExec::new(union.clone(), 3));

        assert_roundtrip(plan, &[union])
    }

    #[test]
    fn test_roundtrip_children_isolator_union() -> datafusion::common::Result<()> {
        let schema = schema_i32("h");
        let left: Arc<dyn ExecutionPlan> = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            stage(None),
        ));
        let right: Arc<dyn ExecutionPlan> = Arc::new(new_network_hash_shuffle_exec(
            Partitioning::RoundRobinBatch(2),
            schema.clone(),
            stage(None),
        ));

        let plan = Arc::new(ChildrenIsolatorUnionExec::from_children_and_task_counts(
            vec![left.clone(), right.clone()],
            vec![2, 2],
            4,
        )?);

        assert_roundtrip(plan, &[left, right])
    }
}
831}