Skip to main content

datum/graph/
junctions.rs

1use super::*;
2
/// Pass-through junction: elements arriving on the inlet are forwarded to the
/// outlet unchanged (see `StageSpec::identity` in the `GraphStage` impl).
///
/// `T` only appears inside a `PhantomData<fn() -> T>` marker; the stage itself
/// never stores a `T`.
#[derive(Clone, Debug)]
pub struct Identity<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Identity<T> {
    /// Creates an identity stage for elements of type `T`.
    #[must_use]
    pub fn new() -> Self {
        Identity {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for Identity<T> {
    /// Same as [`Identity::new`].
    fn default() -> Self {
        Identity::new()
    }
}
22
23impl<T> GraphStage for Identity<T>
24where
25    T: Clone + Send + 'static,
26{
27    type Shape = FlowShape<T, T>;
28
29    fn name(&self) -> &str {
30        "Identity"
31    }
32
33    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
34        let first_id = next_port_id_block(2);
35        FlowShape::new(
36            Inlet::with_arc_name(first_id, identity_inlet_name()),
37            Outlet::with_arc_name(first_id.offset(1), identity_outlet_name()),
38        )
39    }
40
41    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
42        StageSpec::identity(shape.inlets(), shape.outlets())
43    }
44
45    fn stage_spec_with_ports(
46        &self,
47        _shape: &Self::Shape,
48        inlets: Vec<AnyInlet>,
49        outlets: Vec<AnyOutlet>,
50    ) -> StageSpec {
51        StageSpec::identity(inlets, outlets)
52    }
53}
54
/// Junction that applies a user-supplied function `In -> Out` to each element.
///
/// The function lives behind an `Arc`, so cloning the stage shares the same
/// function instead of duplicating it.
#[derive(Clone)]
pub struct MapStage<In: 'static, Out: 'static> {
    f: Arc<dyn Fn(In) -> Out + Send + Sync>,
    _marker: PhantomData<fn(In) -> Out>,
}

impl<In: 'static, Out: 'static> fmt::Debug for MapStage<In, Out> {
    /// The stored closure has no `Debug` impl, so only the stage name is shown.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("MapStage");
        builder.field("name", &"Map");
        builder.finish_non_exhaustive()
    }
}

impl<In: 'static, Out: 'static> MapStage<In, Out> {
    /// Wraps `f` in an `Arc` so every clone of the stage shares it.
    #[must_use]
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(In) -> Out + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(In) -> Out + Send + Sync> = Arc::new(f);
        MapStage {
            f: shared,
            _marker: PhantomData,
        }
    }
}
81
82impl<In, Out> GraphStage for MapStage<In, Out>
83where
84    In: Clone + Send + 'static,
85    Out: Clone + Send + 'static,
86{
87    type Shape = FlowShape<In, Out>;
88
89    fn name(&self) -> &str {
90        "Map"
91    }
92
93    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
94        let first_id = next_port_id_block(2);
95        FlowShape::new(
96            Inlet::with_arc_name(first_id, map_inlet_name()),
97            Outlet::with_arc_name(first_id.offset(1), map_outlet_name()),
98        )
99    }
100
101    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
102        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
103    }
104
105    fn stage_spec_with_ports(
106        &self,
107        _shape: &Self::Shape,
108        inlets: Vec<AnyInlet>,
109        outlets: Vec<AnyOutlet>,
110    ) -> StageSpec {
111        let f = Arc::clone(&self.f);
112        let typed = Arc::new(Arc::clone(&self.f)) as Arc<StageTypedMapFn>;
113        let mapper = Arc::new(move |value: DatumValue| {
114            let value: In = downcast_datum(value, "map", || "Map.in")?;
115            Ok(datum(f(value)))
116        });
117        StageSpec::map(map_stage_name(), inlets, outlets, mapper, typed)
118    }
119}
120
121#[derive(Clone, Debug)]
122pub struct Buffer<T: 'static> {
123    capacity: usize,
124    strategy: OverflowStrategy,
125    _marker: PhantomData<fn() -> T>,
126}
127
128impl<T: 'static> Buffer<T> {
129    #[must_use]
130    pub fn new(capacity: usize, strategy: OverflowStrategy) -> Self {
131        assert!(capacity > 0, "buffer capacity must be greater than zero");
132        Self {
133            capacity,
134            strategy,
135            _marker: PhantomData,
136        }
137    }
138}
139
140impl<T> GraphStage for Buffer<T>
141where
142    T: Clone + Send + 'static,
143{
144    type Shape = FlowShape<T, T>;
145
146    fn name(&self) -> &str {
147        "Buffer"
148    }
149
150    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
151        let first_id = next_port_id_block(2);
152        FlowShape::new(
153            Inlet::with_id(first_id, "Buffer.in"),
154            Outlet::with_id(first_id.offset(1), "Buffer.out"),
155        )
156    }
157
158    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
159        // In the synchronous fused executor a buffer drains one element per step
160        // and never overflows, so it is a pass-through on the typed cyclic path.
161        // The erased executor ignores this and still runs the full buffer logic.
162        StageSpec::opaque(self.name(), shape.inlets(), shape.outlets())
163            .with_typed_cyclic(TypedCyclicOp::BufferPassthrough)
164    }
165
166    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
167        struct State<T> {
168            queue: VecDeque<T>,
169            upstream_closed: bool,
170        }
171
172        fn lock_state<T>(
173            state: &Arc<Mutex<State<T>>>,
174        ) -> StreamResult<std::sync::MutexGuard<'_, State<T>>> {
175            state
176                .lock()
177                .map_err(|_| StreamError::Failed("graph buffer state poisoned".into()))
178        }
179
180        fn enqueue<T>(
181            state: &mut State<T>,
182            value: T,
183            capacity: usize,
184            strategy: OverflowStrategy,
185        ) -> StreamResult<()> {
186            if state.queue.len() < capacity {
187                state.queue.push_back(value);
188                return Ok(());
189            }
190
191            match strategy {
192                OverflowStrategy::Backpressure | OverflowStrategy::Fail => Err(
193                    StreamError::Failed(format!("Buffer overflow (max capacity was: {capacity})!")),
194                ),
195                OverflowStrategy::DropHead => {
196                    let _ = state.queue.pop_front();
197                    state.queue.push_back(value);
198                    Ok(())
199                }
200                OverflowStrategy::DropTail => {
201                    let _ = state.queue.pop_back();
202                    state.queue.push_back(value);
203                    Ok(())
204                }
205                OverflowStrategy::DropBuffer => {
206                    state.queue.clear();
207                    state.queue.push_back(value);
208                    Ok(())
209                }
210                OverflowStrategy::DropNew => Ok(()),
211            }
212        }
213
214        fn drain_one<T>(
215            logic: &mut GraphStageLogic,
216            outlet: &Outlet<T>,
217            state: &Arc<Mutex<State<T>>>,
218        ) -> StreamResult<()>
219        where
220            T: Clone + Send + 'static,
221        {
222            if !logic.is_available(outlet) {
223                return Ok(());
224            }
225            let next = lock_state(state)?.queue.pop_front();
226            if let Some(value) = next {
227                logic.push(outlet, value)?;
228            }
229            Ok(())
230        }
231
232        fn maybe_pull<T>(
233            logic: &mut GraphStageLogic,
234            inlet: &Inlet<T>,
235            state: &Arc<Mutex<State<T>>>,
236            capacity: usize,
237        ) -> StreamResult<()>
238        where
239            T: 'static,
240        {
241            let can_pull = {
242                let state = lock_state(state)?;
243                !state.upstream_closed && state.queue.len() < capacity
244            };
245            if can_pull && !logic.has_been_pulled(inlet) {
246                logic.pull(inlet)?;
247            }
248            Ok(())
249        }
250
251        fn maybe_complete<T>(
252            logic: &mut GraphStageLogic,
253            outlet: &Outlet<T>,
254            state: &Arc<Mutex<State<T>>>,
255        ) -> StreamResult<()>
256        where
257            T: 'static,
258        {
259            let done = {
260                let state = lock_state(state)?;
261                state.upstream_closed && state.queue.is_empty()
262            };
263            if done && !logic.is_closed(outlet) {
264                logic.complete(outlet)?;
265            }
266            Ok(())
267        }
268
269        struct In<T: 'static> {
270            inlet: Inlet<T>,
271            outlet: Outlet<T>,
272            state: Arc<Mutex<State<T>>>,
273            capacity: usize,
274            strategy: OverflowStrategy,
275        }
276
277        impl<T> InHandler for In<T>
278        where
279            T: Clone + Send + 'static,
280        {
281            fn on_push(
282                &mut self,
283                logic: &mut GraphStageLogic,
284                _inlet: AnyInlet,
285            ) -> StreamResult<()> {
286                let value = logic.grab(&self.inlet)?;
287                {
288                    let mut state = lock_state(&self.state)?;
289                    enqueue(&mut state, value, self.capacity, self.strategy)?;
290                }
291                drain_one(logic, &self.outlet, &self.state)?;
292                maybe_pull(logic, &self.inlet, &self.state, self.capacity)?;
293                maybe_complete(logic, &self.outlet, &self.state)
294            }
295
296            fn on_upstream_finish(
297                &mut self,
298                logic: &mut GraphStageLogic,
299                _inlet: AnyInlet,
300            ) -> StreamResult<()> {
301                lock_state(&self.state)?.upstream_closed = true;
302                drain_one(logic, &self.outlet, &self.state)?;
303                maybe_complete(logic, &self.outlet, &self.state)
304            }
305        }
306
307        struct Out<T: 'static> {
308            inlet: Inlet<T>,
309            outlet: Outlet<T>,
310            state: Arc<Mutex<State<T>>>,
311            capacity: usize,
312        }
313
314        impl<T> OutHandler for Out<T>
315        where
316            T: Clone + Send + 'static,
317        {
318            fn on_pull(
319                &mut self,
320                logic: &mut GraphStageLogic,
321                _outlet: AnyOutlet,
322            ) -> StreamResult<()> {
323                drain_one(logic, &self.outlet, &self.state)?;
324                maybe_pull(logic, &self.inlet, &self.state, self.capacity)?;
325                maybe_complete(logic, &self.outlet, &self.state)
326            }
327
328            fn on_downstream_finish(
329                &mut self,
330                logic: &mut GraphStageLogic,
331                _outlet: AnyOutlet,
332            ) -> StreamResult<()> {
333                logic.complete_stage()
334            }
335        }
336
337        let state = Arc::new(Mutex::new(State {
338            queue: VecDeque::new(),
339            upstream_closed: false,
340        }));
341        let mut logic = GraphStageLogic::new(shape);
342        logic.pull(&shape.inlet()).unwrap();
343        logic
344            .set_handler(
345                &shape.inlet(),
346                Box::new(In {
347                    inlet: shape.inlet(),
348                    outlet: shape.outlet(),
349                    state: Arc::clone(&state),
350                    capacity: self.capacity,
351                    strategy: self.strategy,
352                }),
353            )
354            .unwrap();
355        logic
356            .set_out_handler(
357                &shape.outlet(),
358                Box::new(Out {
359                    inlet: shape.inlet(),
360                    outlet: shape.outlet(),
361                    state,
362                    capacity: self.capacity,
363                }),
364            )
365            .unwrap();
366        logic
367    }
368}
369
/// Junction that forwards elements for as long as `predicate` returns `true`.
///
/// The predicate sits behind an `Arc`, so clones of the stage share it.
#[derive(Clone)]
pub struct TakeWhile<T: 'static> {
    predicate: Arc<dyn Fn(&T) -> bool + Send + Sync>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for TakeWhile<T> {
    /// Closures have no `Debug` impl, so only the stage name is printed.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("TakeWhile");
        builder.field("name", &"TakeWhile");
        builder.finish_non_exhaustive()
    }
}

impl<T: 'static> TakeWhile<T> {
    /// Creates the stage from `predicate`.
    #[must_use]
    pub fn new<F>(predicate: F) -> Self
    where
        F: Fn(&T) -> bool + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(&T) -> bool + Send + Sync> = Arc::new(predicate);
        TakeWhile {
            predicate: shared,
            _marker: PhantomData,
        }
    }
}
396
397impl<T> GraphStage for TakeWhile<T>
398where
399    T: Clone + Send + 'static,
400{
401    type Shape = FlowShape<T, T>;
402
403    fn name(&self) -> &str {
404        "TakeWhile"
405    }
406
407    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
408        let first_id = next_port_id_block(2);
409        FlowShape::new(
410            Inlet::with_id(first_id, "TakeWhile.in"),
411            Outlet::with_id(first_id.offset(1), "TakeWhile.out"),
412        )
413    }
414
415    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
416        // Surface the typed predicate so the cyclic typed kernel can apply it
417        // directly (no `DatumValue` boxing). Stored as `Arc<dyn Any>` and
418        // down-cast at plan time, mirroring the typed Map fn. The erased path is
419        // unaffected — it keeps running the `GraphStageLogic` below.
420        let predicate: Arc<dyn Fn(&T) -> bool + Send + Sync> = Arc::clone(&self.predicate);
421        StageSpec::opaque(self.name(), shape.inlets(), shape.outlets()).with_typed_cyclic(
422            TypedCyclicOp::TakeWhile(Arc::new(predicate) as Arc<dyn Any + Send + Sync>),
423        )
424    }
425
426    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
427        struct In<T: 'static> {
428            inlet: Inlet<T>,
429            outlet: Outlet<T>,
430            predicate: Arc<dyn Fn(&T) -> bool + Send + Sync>,
431        }
432
433        impl<T> InHandler for In<T>
434        where
435            T: Clone + Send + 'static,
436        {
437            fn on_push(
438                &mut self,
439                logic: &mut GraphStageLogic,
440                _inlet: AnyInlet,
441            ) -> StreamResult<()> {
442                let value = logic.grab(&self.inlet)?;
443                if (self.predicate)(&value) {
444                    logic.emit(&self.outlet, value)
445                } else if logic.is_closed(&self.outlet) {
446                    Ok(())
447                } else {
448                    logic.complete(&self.outlet)
449                }
450            }
451
452            fn on_upstream_finish(
453                &mut self,
454                logic: &mut GraphStageLogic,
455                _inlet: AnyInlet,
456            ) -> StreamResult<()> {
457                if logic.is_closed(&self.outlet) {
458                    Ok(())
459                } else {
460                    logic.complete(&self.outlet)
461                }
462            }
463        }
464
465        struct Out<T: 'static> {
466            inlet: Inlet<T>,
467        }
468
469        impl<T> OutHandler for Out<T>
470        where
471            T: Clone + Send + 'static,
472        {
473            fn on_pull(
474                &mut self,
475                logic: &mut GraphStageLogic,
476                _outlet: AnyOutlet,
477            ) -> StreamResult<()> {
478                if !logic.has_been_pulled(&self.inlet) {
479                    logic.pull(&self.inlet)?;
480                }
481                Ok(())
482            }
483
484            fn on_downstream_finish(
485                &mut self,
486                logic: &mut GraphStageLogic,
487                _outlet: AnyOutlet,
488            ) -> StreamResult<()> {
489                logic.complete_stage()
490            }
491        }
492
493        let mut logic = GraphStageLogic::new(shape);
494        logic
495            .set_handler(
496                &shape.inlet(),
497                Box::new(In {
498                    inlet: shape.inlet(),
499                    outlet: shape.outlet(),
500                    predicate: Arc::clone(&self.predicate),
501                }),
502            )
503            .unwrap();
504        logic
505            .set_out_handler(
506                &shape.outlet(),
507                Box::new(Out {
508                    inlet: shape.inlet(),
509                }),
510            )
511            .unwrap();
512        logic
513    }
514}
515
/// Fan-out junction with one inlet and `outputs` outlets. Routing is
/// delegated to `StageSpec::broadcast`.
#[derive(Clone, Debug)]
pub struct Broadcast<T: 'static> {
    outputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Broadcast<T> {
    /// Creates a broadcast junction with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        if outputs == 0 {
            panic!("broadcast output count must be greater than zero");
        }
        Broadcast {
            outputs,
            _marker: PhantomData,
        }
    }
}
535
536impl<T> GraphStage for Broadcast<T>
537where
538    T: Clone + Send + 'static,
539{
540    type Shape = FanOutShape<T, T>;
541
542    fn name(&self) -> &str {
543        "Broadcast"
544    }
545
546    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
547        let inlet = allocator.inlet_arc(broadcast_inlet_name());
548        let outlets = (0..self.outputs)
549            .map(|index| allocator.outlet(format!("Broadcast.out{index}")))
550            .collect();
551        FanOutShape::new(inlet, outlets)
552    }
553
554    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
555        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
556    }
557
558    fn stage_spec_with_ports(
559        &self,
560        _shape: &Self::Shape,
561        inlets: Vec<AnyInlet>,
562        outlets: Vec<AnyOutlet>,
563    ) -> StageSpec {
564        StageSpec::broadcast(broadcast_stage_name(), inlets, outlets)
565    }
566}
567
/// Fan-out junction with one inlet and `outputs` outlets. Distribution is
/// delegated to `StageSpec::balance`.
#[derive(Clone, Debug)]
pub struct Balance<T: 'static> {
    outputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Balance<T> {
    /// Creates a balance junction with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        if outputs == 0 {
            panic!("balance output count must be greater than zero");
        }
        Balance {
            outputs,
            _marker: PhantomData,
        }
    }
}
587
588impl<T> GraphStage for Balance<T>
589where
590    T: Clone + Send + 'static,
591{
592    type Shape = FanOutShape<T, T>;
593
594    fn name(&self) -> &str {
595        "Balance"
596    }
597
598    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
599        let inlet = allocator.inlet_arc(balance_inlet_name());
600        let outlets = (0..self.outputs)
601            .map(|index| allocator.outlet(format!("Balance.out{index}")))
602            .collect();
603        FanOutShape::new(inlet, outlets)
604    }
605
606    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
607        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
608    }
609
610    fn stage_spec_with_ports(
611        &self,
612        _shape: &Self::Shape,
613        inlets: Vec<AnyInlet>,
614        outlets: Vec<AnyOutlet>,
615    ) -> StageSpec {
616        StageSpec::balance(balance_stage_name(), inlets, outlets)
617    }
618}
619
/// Fan-in junction with `inputs` inlets and one outlet. Merging is delegated
/// to `StageSpec::merge`.
#[derive(Clone, Debug)]
pub struct Merge<T: 'static> {
    inputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Merge<T> {
    /// Creates a merge junction with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs != 0, "merge input count must be greater than zero");
        Merge {
            inputs,
            _marker: PhantomData,
        }
    }
}
636
637impl<T> GraphStage for Merge<T>
638where
639    T: Clone + Send + 'static,
640{
641    type Shape = FanInShape<T, T>;
642
643    fn name(&self) -> &str {
644        "Merge"
645    }
646
647    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
648        let inlets = (0..self.inputs)
649            .map(|index| allocator.inlet(format!("Merge.in{index}")))
650            .collect();
651        FanInShape::new(inlets, allocator.outlet_arc(merge_outlet_name()))
652    }
653
654    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
655        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
656    }
657
658    fn stage_spec_with_ports(
659        &self,
660        _shape: &Self::Shape,
661        inlets: Vec<AnyInlet>,
662        outlets: Vec<AnyOutlet>,
663    ) -> StageSpec {
664        StageSpec::merge(merge_stage_name(), inlets, outlets)
665    }
666}
667
/// Fan-in junction with `inputs` inlets and one outlet. Concatenation is
/// delegated to `StageSpec::concat`.
#[derive(Clone, Debug)]
pub struct Concat<T: 'static> {
    inputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Concat<T> {
    /// Creates a concat junction with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs >= 2, "concat input count must be greater than one");
        Concat {
            inputs,
            _marker: PhantomData,
        }
    }
}
684
685impl<T> GraphStage for Concat<T>
686where
687    T: Clone + Send + 'static,
688{
689    type Shape = FanInShape<T, T>;
690
691    fn name(&self) -> &str {
692        "Concat"
693    }
694
695    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
696        let inlets = (0..self.inputs)
697            .map(|index| allocator.inlet(format!("Concat.in{index}")))
698            .collect();
699        FanInShape::new(inlets, allocator.outlet_arc(concat_outlet_name()))
700    }
701
702    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
703        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
704    }
705
706    fn stage_spec_with_ports(
707        &self,
708        _shape: &Self::Shape,
709        inlets: Vec<AnyInlet>,
710        outlets: Vec<AnyOutlet>,
711    ) -> StageSpec {
712        StageSpec::concat(concat_stage_name(), inlets, outlets)
713    }
714}
715
/// Fan-in junction with a primary and a secondary inlet and one outlet.
/// Selection is delegated to `StageSpec::or_else`.
#[derive(Clone, Debug)]
pub struct OrElse<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> OrElse<T> {
    /// Creates an or-else junction for elements of type `T`.
    #[must_use]
    pub fn new() -> Self {
        OrElse {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for OrElse<T> {
    /// Same as [`OrElse::new`].
    fn default() -> Self {
        OrElse::new()
    }
}
735
736impl<T> GraphStage for OrElse<T>
737where
738    T: Clone + Send + 'static,
739{
740    type Shape = FanInShape<T, T>;
741
742    fn name(&self) -> &str {
743        "OrElse"
744    }
745
746    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
747        let inlets = vec![
748            allocator.inlet_arc(or_else_primary_name()),
749            allocator.inlet_arc(or_else_secondary_name()),
750        ];
751        FanInShape::new(inlets, allocator.outlet_arc(or_else_outlet_name()))
752    }
753
754    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
755        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
756    }
757
758    fn stage_spec_with_ports(
759        &self,
760        _shape: &Self::Shape,
761        inlets: Vec<AnyInlet>,
762        outlets: Vec<AnyOutlet>,
763    ) -> StageSpec {
764        StageSpec::or_else(or_else_stage_name(), inlets, outlets)
765    }
766}
767
/// Fan-in junction with `inputs` inlets and one outlet. `segment_size` and
/// `eager_close` are passed unchanged to `StageSpec::interleave`.
#[derive(Clone, Debug)]
pub struct Interleave<T: 'static> {
    inputs: usize,
    segment_size: usize,
    eager_close: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Interleave<T> {
    /// Creates an interleave junction with `eager_close` disabled.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Interleave::new_with_eager_close`].
    #[must_use]
    pub fn new(inputs: usize, segment_size: usize) -> Self {
        Interleave::new_with_eager_close(inputs, segment_size, false)
    }

    /// Creates an interleave junction with an explicit `eager_close` flag.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two, or (checked second) if
    /// `segment_size` is zero.
    #[must_use]
    pub fn new_with_eager_close(inputs: usize, segment_size: usize, eager_close: bool) -> Self {
        if inputs < 2 {
            panic!("interleave input count must be greater than one");
        }
        if segment_size == 0 {
            panic!("interleave segment size must be greater than zero");
        }
        Interleave {
            inputs,
            segment_size,
            eager_close,
            _marker: PhantomData,
        }
    }
}
800
801impl<T> GraphStage for Interleave<T>
802where
803    T: Clone + Send + 'static,
804{
805    type Shape = FanInShape<T, T>;
806
807    fn name(&self) -> &str {
808        "Interleave"
809    }
810
811    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
812        let inlets = (0..self.inputs)
813            .map(|index| allocator.inlet(format!("Interleave.in{index}")))
814            .collect();
815        FanInShape::new(inlets, allocator.outlet_arc(interleave_outlet_name()))
816    }
817
818    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
819        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
820    }
821
822    fn stage_spec_with_ports(
823        &self,
824        _shape: &Self::Shape,
825        inlets: Vec<AnyInlet>,
826        outlets: Vec<AnyOutlet>,
827    ) -> StageSpec {
828        StageSpec::interleave(
829            interleave_stage_name(),
830            inlets,
831            outlets,
832            self.segment_size,
833            self.eager_close,
834        )
835    }
836}
837
/// Fan-in junction with one preferred inlet, `secondary_ports` secondary
/// inlets and one outlet. Selection is delegated to
/// `StageSpec::merge_preferred`.
#[derive(Clone, Debug)]
pub struct MergePreferred<T: 'static> {
    secondary_ports: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePreferred<T> {
    /// Creates a merge-preferred junction with `secondary_ports` secondary inlets.
    ///
    /// # Panics
    ///
    /// Panics if `secondary_ports` is zero.
    #[must_use]
    pub fn new(secondary_ports: usize) -> Self {
        if secondary_ports == 0 {
            panic!("merge-preferred secondary input count must be greater than zero");
        }
        MergePreferred {
            secondary_ports,
            _marker: PhantomData,
        }
    }
}
857
858impl<T> GraphStage for MergePreferred<T>
859where
860    T: Clone + Send + 'static,
861{
862    type Shape = MergePreferredShape<T>;
863
864    fn name(&self) -> &str {
865        "MergePreferred"
866    }
867
868    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
869        let preferred = allocator.inlet_arc(merge_preferred_preferred_name());
870        let secondary = (0..self.secondary_ports)
871            .map(|index| allocator.inlet(format!("MergePreferred.in{index}")))
872            .collect();
873        MergePreferredShape::new(
874            preferred,
875            secondary,
876            allocator.outlet_arc(merge_preferred_outlet_name()),
877        )
878    }
879
880    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
881        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
882    }
883
884    fn stage_spec_with_ports(
885        &self,
886        _shape: &Self::Shape,
887        inlets: Vec<AnyInlet>,
888        outlets: Vec<AnyOutlet>,
889    ) -> StageSpec {
890        StageSpec::merge_preferred(merge_preferred_stage_name(), inlets, outlets)
891    }
892}
893
/// Fan-in junction with one inlet per entry in `weights` and one outlet. The
/// weights are passed to `StageSpec::merge_prioritized`.
#[derive(Clone, Debug)]
pub struct MergePrioritized<T: 'static> {
    weights: Vec<usize>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePrioritized<T> {
    /// Creates a prioritized merge with one inlet per weight.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty, or (checked second) if any weight is zero.
    #[must_use]
    pub fn new(weights: Vec<usize>) -> Self {
        assert!(!weights.is_empty(), "prioritized merge must have inputs");
        let has_zero_weight = weights.iter().any(|&weight| weight == 0);
        assert!(
            !has_zero_weight,
            "prioritized merge weights must be greater than zero"
        );
        MergePrioritized {
            weights,
            _marker: PhantomData,
        }
    }
}
914
915impl<T> GraphStage for MergePrioritized<T>
916where
917    T: Clone + Send + 'static,
918{
919    type Shape = FanInShape<T, T>;
920
921    fn name(&self) -> &str {
922        "MergePrioritized"
923    }
924
925    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
926        let inlets = (0..self.weights.len())
927            .map(|index| allocator.inlet(format!("MergePrioritized.in{index}")))
928            .collect();
929        FanInShape::new(
930            inlets,
931            allocator.outlet_arc(merge_prioritized_outlet_name()),
932        )
933    }
934
935    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
936        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
937    }
938
939    fn stage_spec_with_ports(
940        &self,
941        _shape: &Self::Shape,
942        inlets: Vec<AnyInlet>,
943        outlets: Vec<AnyOutlet>,
944    ) -> StageSpec {
945        StageSpec::merge_prioritized(
946            merge_prioritized_stage_name(),
947            inlets,
948            outlets,
949            self.weights.clone(),
950        )
951    }
952}
953
/// Fan-in junction pairing a `Left` element with a `Right` element into a
/// `(Left, Right)` tuple (see `stage_spec_with_ports`).
#[derive(Clone, Debug)]
pub struct Zip<Left: 'static, Right: 'static> {
    _marker: PhantomData<fn() -> (Left, Right)>,
}

impl<Left: 'static, Right: 'static> Zip<Left, Right> {
    /// Creates a zip junction for `Left` and `Right` element types.
    #[must_use]
    pub fn new() -> Self {
        Zip {
            _marker: PhantomData,
        }
    }
}

impl<Left: 'static, Right: 'static> Default for Zip<Left, Right> {
    /// Same as [`Zip::new`].
    fn default() -> Self {
        Zip::new()
    }
}
973
974impl<Left, Right> GraphStage for Zip<Left, Right>
975where
976    Left: Clone + Send + 'static,
977    Right: Clone + Send + 'static,
978{
979    type Shape = ZipShape<Left, Right>;
980
981    fn name(&self) -> &str {
982        "Zip"
983    }
984
985    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
986        let first_id = next_port_id_block(3);
987        ZipShape::new(
988            Inlet::with_arc_name(first_id, zip_in0_name()),
989            Inlet::with_arc_name(first_id.offset(1), zip_in1_name()),
990            Outlet::with_arc_name(first_id.offset(2), zip_outlet_name()),
991        )
992    }
993
994    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
995        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
996    }
997
998    fn stage_spec_with_ports(
999        &self,
1000        _shape: &Self::Shape,
1001        inlets: Vec<AnyInlet>,
1002        outlets: Vec<AnyOutlet>,
1003    ) -> StageSpec {
1004        let zip = Arc::new(move |left: DatumValue, right: DatumValue| {
1005            let left: Left = downcast_datum(left, "zip", || "Zip.in0")?;
1006            let right: Right = downcast_datum(right, "zip", || "Zip.in1")?;
1007            Ok(datum((left, right)))
1008        });
1009        StageSpec::zip(zip_stage_name(), inlets, outlets, zip)
1010    }
1011}
1012
/// Fan-in junction with two inlets that merges their elements using `T`'s
/// ordering. Its `GraphStage` impl requires `T: Ord`.
#[derive(Clone, Debug)]
pub struct MergeSorted<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeSorted<T> {
    /// Creates a merge-sorted junction for elements of type `T`.
    #[must_use]
    pub fn new() -> Self {
        MergeSorted {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for MergeSorted<T> {
    /// Same as [`MergeSorted::new`].
    fn default() -> Self {
        MergeSorted::new()
    }
}
1032
1033impl<T> GraphStage for MergeSorted<T>
1034where
1035    T: Clone + Ord + Send + 'static,
1036{
1037    type Shape = FanInShape<T, T>;
1038
1039    fn name(&self) -> &str {
1040        "MergeSorted"
1041    }
1042
1043    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1044        let inlets = vec![
1045            allocator.inlet("MergeSorted.in0"),
1046            allocator.inlet("MergeSorted.in1"),
1047        ];
1048        FanInShape::new(inlets, allocator.outlet("MergeSorted.out"))
1049    }
1050
1051    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1052        let compare = Arc::new(
1053            move |a: &DatumValue, b: &DatumValue| -> std::cmp::Ordering {
1054                let a_t: &T = a
1055                    .as_any_ref()
1056                    .downcast_ref::<T>()
1057                    .expect("merge-sorted compare: wrong element type");
1058                let b_t: &T = b
1059                    .as_any_ref()
1060                    .downcast_ref::<T>()
1061                    .expect("merge-sorted compare: wrong element type");
1062                a_t.cmp(b_t)
1063            },
1064        );
1065        StageSpec::merge_sorted(
1066            Arc::from(self.name()),
1067            shape.inlets(),
1068            shape.outlets(),
1069            compare,
1070        )
1071    }
1072
1073    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1074        struct State<T> {
1075            left: VecDeque<T>,
1076            right: VecDeque<T>,
1077            left_closed: bool,
1078            right_closed: bool,
1079            pending: VecDeque<T>,
1080        }
1081
1082        impl<T> Default for State<T> {
1083            fn default() -> Self {
1084                Self {
1085                    left: VecDeque::new(),
1086                    right: VecDeque::new(),
1087                    left_closed: false,
1088                    right_closed: false,
1089                    pending: VecDeque::new(),
1090                }
1091            }
1092        }
1093
1094        fn maybe_queue_output<T>(state: &mut State<T>) -> bool
1095        where
1096            T: Clone + Ord,
1097        {
1098            let next = match (state.left.front(), state.right.front()) {
1099                (Some(left), Some(right)) => {
1100                    if left <= right {
1101                        state.left.pop_front()
1102                    } else {
1103                        state.right.pop_front()
1104                    }
1105                }
1106                (Some(_), None) if state.right_closed => state.left.pop_front(),
1107                (None, Some(_)) if state.left_closed => state.right.pop_front(),
1108                _ => None,
1109            };
1110            if let Some(value) = next {
1111                state.pending.push_back(value);
1112                true
1113            } else {
1114                false
1115            }
1116        }
1117
1118        fn maybe_complete<T>(
1119            logic: &mut GraphStageLogic,
1120            outlet: &Outlet<T>,
1121            state: &State<T>,
1122        ) -> StreamResult<()>
1123        where
1124            T: Clone + Send + 'static,
1125        {
1126            if state.left_closed
1127                && state.right_closed
1128                && state.left.is_empty()
1129                && state.right.is_empty()
1130                && state.pending.is_empty()
1131                && !logic.is_closed(outlet)
1132            {
1133                logic.complete(outlet)?;
1134            }
1135            Ok(())
1136        }
1137
1138        fn maybe_pull<T>(
1139            logic: &mut GraphStageLogic,
1140            left: &Inlet<T>,
1141            right: &Inlet<T>,
1142            state: &State<T>,
1143        ) -> StreamResult<()>
1144        where
1145            T: Clone + Send + 'static,
1146        {
1147            if state.left.is_empty() && !state.left_closed && !logic.has_been_pulled(left) {
1148                logic.pull(left)?;
1149            }
1150            if state.right.is_empty() && !state.right_closed && !logic.has_been_pulled(right) {
1151                logic.pull(right)?;
1152            }
1153            Ok(())
1154        }
1155
1156        fn maybe_drain<T>(
1157            logic: &mut GraphStageLogic,
1158            outlet: &Outlet<T>,
1159            state: &Arc<Mutex<State<T>>>,
1160        ) -> StreamResult<()>
1161        where
1162            T: Clone + Ord + Send + 'static,
1163        {
1164            let next = if logic.is_available(outlet) {
1165                state
1166                    .lock()
1167                    .expect("merge-sorted state poisoned")
1168                    .pending
1169                    .pop_front()
1170            } else {
1171                None
1172            };
1173            if let Some(value) = next {
1174                logic.push(outlet, value)?;
1175            }
1176            Ok(())
1177        }
1178
1179        struct In<T: 'static> {
1180            inlet_id: PortId,
1181            left: Inlet<T>,
1182            right: Inlet<T>,
1183            outlet: Outlet<T>,
1184            state: Arc<Mutex<State<T>>>,
1185        }
1186
1187        impl<T> InHandler for In<T>
1188        where
1189            T: Clone + Ord + Send + 'static,
1190        {
1191            fn on_push(
1192                &mut self,
1193                logic: &mut GraphStageLogic,
1194                _inlet: AnyInlet,
1195            ) -> StreamResult<()> {
1196                let value: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1197                    downcast_datum(value, "grab", || {
1198                        format!("inlet#{}", self.inlet_id.as_usize())
1199                    })
1200                })?;
1201                {
1202                    let mut state = self.state.lock().expect("merge-sorted state poisoned");
1203                    if self.inlet_id == self.left.id() {
1204                        state.left.push_back(value);
1205                    } else {
1206                        state.right.push_back(value);
1207                    }
1208                    while maybe_queue_output(&mut state) {}
1209                }
1210                maybe_drain(logic, &self.outlet, &self.state)?;
1211                let state = self.state.lock().expect("merge-sorted state poisoned");
1212                maybe_complete(logic, &self.outlet, &state)?;
1213                maybe_pull(logic, &self.left, &self.right, &state)
1214            }
1215
1216            fn on_upstream_finish(
1217                &mut self,
1218                logic: &mut GraphStageLogic,
1219                _inlet: AnyInlet,
1220            ) -> StreamResult<()> {
1221                {
1222                    let mut state = self.state.lock().expect("merge-sorted state poisoned");
1223                    if self.inlet_id == self.left.id() {
1224                        state.left_closed = true;
1225                    } else {
1226                        state.right_closed = true;
1227                    }
1228                    while maybe_queue_output(&mut state) {}
1229                }
1230                maybe_drain(logic, &self.outlet, &self.state)?;
1231                let state = self.state.lock().expect("merge-sorted state poisoned");
1232                maybe_complete(logic, &self.outlet, &state)?;
1233                maybe_pull(logic, &self.left, &self.right, &state)
1234            }
1235        }
1236
1237        struct Out<T: 'static> {
1238            left: Inlet<T>,
1239            right: Inlet<T>,
1240            outlet: Outlet<T>,
1241            state: Arc<Mutex<State<T>>>,
1242        }
1243
1244        impl<T> OutHandler for Out<T>
1245        where
1246            T: Clone + Ord + Send + 'static,
1247        {
1248            fn on_pull(
1249                &mut self,
1250                logic: &mut GraphStageLogic,
1251                _outlet: AnyOutlet,
1252            ) -> StreamResult<()> {
1253                maybe_drain(logic, &self.outlet, &self.state)?;
1254                let state = self.state.lock().expect("merge-sorted state poisoned");
1255                maybe_complete(logic, &self.outlet, &state)?;
1256                maybe_pull(logic, &self.left, &self.right, &state)
1257            }
1258        }
1259
1260        let state = Arc::new(Mutex::new(State::<T>::default()));
1261        let left = shape.inlet(0).expect("merge-sorted left inlet");
1262        let right = shape.inlet(1).expect("merge-sorted right inlet");
1263        let outlet = shape.outlet();
1264        let mut logic = GraphStageLogic::new(shape);
1265        logic
1266            .set_handler(
1267                &left,
1268                Box::new(In {
1269                    inlet_id: left.id(),
1270                    left: left.clone(),
1271                    right: right.clone(),
1272                    outlet: outlet.clone(),
1273                    state: Arc::clone(&state),
1274                }),
1275            )
1276            .unwrap();
1277        logic
1278            .set_handler(
1279                &right,
1280                Box::new(In {
1281                    inlet_id: right.id(),
1282                    left: left.clone(),
1283                    right: right.clone(),
1284                    outlet: outlet.clone(),
1285                    state: Arc::clone(&state),
1286                }),
1287            )
1288            .unwrap();
1289        logic
1290            .set_out_handler(
1291                &outlet.clone(),
1292                Box::new(Out {
1293                    left,
1294                    right,
1295                    outlet: outlet.clone(),
1296                    state,
1297                }),
1298            )
1299            .unwrap();
1300        logic
1301    }
1302}
1303
/// Fan-in stage over `inputs` inlets that emits elements in the order of a
/// `u64` sequence number pulled out of each element by `extract_sequence`.
#[derive(Clone)]
pub struct MergeSequence<T: 'static> {
    inputs: usize,
    extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for MergeSequence<T> {
    /// Shows the input count; the extractor closure is elided (`..`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MergeSequence");
        builder.field("inputs", &self.inputs);
        builder.finish_non_exhaustive()
    }
}

impl<T: 'static> MergeSequence<T> {
    /// Creates a stage with `inputs` inlets and the given sequence extractor.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new<F>(inputs: usize, extract_sequence: F) -> Self
    where
        F: Fn(&T) -> u64 + Send + Sync + 'static,
    {
        assert!(
            inputs >= 2,
            "merge sequence input count must be greater than one"
        );
        let extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync> = Arc::new(extract_sequence);
        MergeSequence {
            inputs,
            extract_sequence,
            _marker: PhantomData,
        }
    }
}
1336
1337impl<T> GraphStage for MergeSequence<T>
1338where
1339    T: Clone + Send + 'static,
1340{
1341    type Shape = FanInShape<T, T>;
1342
1343    fn name(&self) -> &str {
1344        "MergeSequence"
1345    }
1346
1347    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1348        let inlets = (0..self.inputs)
1349            .map(|index| allocator.inlet(format!("MergeSequence.in{index}")))
1350            .collect();
1351        FanInShape::new(inlets, allocator.outlet("MergeSequence.out"))
1352    }
1353
1354    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1355        let extract = Arc::clone(&self.extract_sequence);
1356        let extract_sequence = Arc::new(move |dv: &DatumValue| -> u64 {
1357            let t: &T = dv
1358                .as_any_ref()
1359                .downcast_ref::<T>()
1360                .expect("merge-sequence extract: wrong element type");
1361            extract(t)
1362        });
1363        // Typed extractor: stored as `Arc<dyn Any + Send + Sync>` and
1364        // down-cast at plan time to `Arc<dyn Fn(&T) -> u64 + Send + Sync>`.
1365        let typed_extract_fn = Arc::clone(&self.extract_sequence);
1366        let typed_extract: Arc<dyn Fn(&T) -> u64 + Send + Sync> = typed_extract_fn;
1367        let typed_extract: Arc<StageTypedSequenceFn> = Arc::new(typed_extract);
1368        StageSpec::merge_sequence(
1369            Arc::from(self.name()),
1370            shape.inlets(),
1371            shape.outlets(),
1372            self.inputs,
1373            extract_sequence,
1374            typed_extract,
1375        )
1376    }
1377
1378    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1379        #[derive(Clone)]
1380        struct Pending<T> {
1381            sequence: u64,
1382            elem: T,
1383        }
1384
1385        struct State<T> {
1386            next_sequence: u64,
1387            pending: Vec<Pending<T>>,
1388            completed: usize,
1389            pending_output: VecDeque<T>,
1390        }
1391
1392        fn try_emit_pending<T>(state: &mut State<T>) -> StreamResult<()>
1393        where
1394            T: Clone + Send + 'static,
1395        {
1396            while let Some(index) = state
1397                .pending
1398                .iter()
1399                .position(|item| item.sequence == state.next_sequence)
1400            {
1401                let item = state.pending.remove(index);
1402                if state
1403                    .pending
1404                    .iter()
1405                    .any(|other| other.sequence == state.next_sequence)
1406                {
1407                    return Err(StreamError::Failed(format!(
1408                        "duplicate sequence {} on merge sequence",
1409                        state.next_sequence
1410                    )));
1411                }
1412                state.pending_output.push_back(item.elem);
1413                state.next_sequence += 1;
1414            }
1415            Ok(())
1416        }
1417
1418        struct In<T: 'static> {
1419            inlet_id: PortId,
1420            inlet_index: usize,
1421            inlet: Inlet<T>,
1422            all_inlets: Vec<Inlet<T>>,
1423            outlet: Outlet<T>,
1424            extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
1425            state: Arc<Mutex<State<T>>>,
1426        }
1427
1428        impl<T> InHandler for In<T>
1429        where
1430            T: Clone + Send + 'static,
1431        {
1432            fn on_push(
1433                &mut self,
1434                logic: &mut GraphStageLogic,
1435                _inlet: AnyInlet,
1436            ) -> StreamResult<()> {
1437                let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1438                    downcast_datum(value, "grab", || {
1439                        format!("inlet#{}", self.inlet_id.as_usize())
1440                    })
1441                })?;
1442                {
1443                    let mut state = self.state.lock().expect("merge-sequence state poisoned");
1444                    let sequence = (self.extract_sequence)(&elem);
1445                    if sequence < state.next_sequence {
1446                        return Err(StreamError::Failed(format!(
1447                            "sequence regression from {} to {} on port {}",
1448                            state.next_sequence, sequence, self.inlet_index
1449                        )));
1450                    }
1451                    state.pending.push(Pending { sequence, elem });
1452                    try_emit_pending(&mut state)?;
1453                }
1454                let next = if logic.is_available(&self.outlet) {
1455                    self.state
1456                        .lock()
1457                        .expect("merge-sequence state poisoned")
1458                        .pending_output
1459                        .pop_front()
1460                } else {
1461                    None
1462                };
1463                if let Some(value) = next {
1464                    logic.push(&self.outlet, value)?;
1465                }
1466                let state = self.state.lock().expect("merge-sequence state poisoned");
1467                if state.completed == self.all_inlets.len()
1468                    && state.pending.is_empty()
1469                    && state.pending_output.is_empty()
1470                {
1471                    logic.complete(&self.outlet)?;
1472                } else if logic.is_available(&self.outlet)
1473                    && state.pending_output.is_empty()
1474                    && state.pending.len() + state.completed == self.all_inlets.len()
1475                {
1476                    return Err(StreamError::Failed(format!(
1477                        "expected sequence {}, but all input ports have pushed or are complete",
1478                        state.next_sequence
1479                    )));
1480                }
1481                if !logic.has_been_pulled(&self.inlet) {
1482                    logic.pull(&self.inlet)?;
1483                }
1484                Ok(())
1485            }
1486
1487            fn on_upstream_finish(
1488                &mut self,
1489                logic: &mut GraphStageLogic,
1490                _inlet: AnyInlet,
1491            ) -> StreamResult<()> {
1492                {
1493                    let mut state = self.state.lock().expect("merge-sequence state poisoned");
1494                    state.completed += 1;
1495                }
1496                let state = self.state.lock().expect("merge-sequence state poisoned");
1497                if state.completed == self.all_inlets.len()
1498                    && state.pending.is_empty()
1499                    && state.pending_output.is_empty()
1500                {
1501                    logic.complete(&self.outlet)?;
1502                } else if logic.is_available(&self.outlet)
1503                    && state.pending_output.is_empty()
1504                    && state.pending.len() + state.completed == self.all_inlets.len()
1505                {
1506                    return Err(StreamError::Failed(format!(
1507                        "expected sequence {}, but all input ports have pushed or are complete",
1508                        state.next_sequence
1509                    )));
1510                }
1511                Ok(())
1512            }
1513        }
1514
1515        struct Out<T: 'static> {
1516            inlets: Vec<Inlet<T>>,
1517            outlet: Outlet<T>,
1518            state: Arc<Mutex<State<T>>>,
1519        }
1520
1521        impl<T> OutHandler for Out<T>
1522        where
1523            T: Clone + Send + 'static,
1524        {
1525            fn on_pull(
1526                &mut self,
1527                logic: &mut GraphStageLogic,
1528                _outlet: AnyOutlet,
1529            ) -> StreamResult<()> {
1530                let next = self
1531                    .state
1532                    .lock()
1533                    .expect("merge-sequence state poisoned")
1534                    .pending_output
1535                    .pop_front();
1536                if let Some(value) = next {
1537                    logic.push(&self.outlet, value)?;
1538                } else {
1539                    let state = self.state.lock().expect("merge-sequence state poisoned");
1540                    if state.completed == self.inlets.len() && state.pending.is_empty() {
1541                        logic.complete(&self.outlet)?;
1542                    } else if state.pending.len() + state.completed == self.inlets.len() {
1543                        return Err(StreamError::Failed(format!(
1544                            "expected sequence {}, but all input ports have pushed or are complete",
1545                            state.next_sequence
1546                        )));
1547                    }
1548                }
1549                for inlet in &self.inlets {
1550                    if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1551                        logic.pull(inlet)?;
1552                    }
1553                }
1554                Ok(())
1555            }
1556        }
1557
1558        let inlets = shape.inlets_vec();
1559        let outlet = shape.outlet();
1560        let state = Arc::new(Mutex::new(State {
1561            next_sequence: 0,
1562            pending: Vec::new(),
1563            completed: 0,
1564            pending_output: VecDeque::new(),
1565        }));
1566        let mut logic = GraphStageLogic::new(shape);
1567        for (index, inlet) in inlets.iter().cloned().enumerate() {
1568            logic
1569                .set_handler(
1570                    &inlet.clone(),
1571                    Box::new(In {
1572                        inlet_id: inlet.id(),
1573                        inlet_index: index,
1574                        inlet: inlet.clone(),
1575                        all_inlets: inlets.clone(),
1576                        outlet: outlet.clone(),
1577                        extract_sequence: Arc::clone(&self.extract_sequence),
1578                        state: Arc::clone(&state),
1579                    }),
1580                )
1581                .unwrap();
1582        }
1583        logic
1584            .set_out_handler(
1585                &outlet.clone(),
1586                Box::new(Out {
1587                    inlets,
1588                    outlet: outlet.clone(),
1589                    state,
1590                }),
1591            )
1592            .unwrap();
1593        logic
1594    }
1595}
1596
/// Fan-in stage over `inputs` inlets. Once every inlet has produced at least
/// one element, it emits a `Vec<T>` snapshot of the latest element from each
/// inlet.
///
/// `eager_complete` selects whether the stage completes when the first inlet
/// finishes rather than waiting for all of them.
#[derive(Clone, Debug)]
pub struct MergeLatest<T: 'static> {
    inputs: usize,
    eager_complete: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeLatest<T> {
    /// Creates a stage with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize, eager_complete: bool) -> Self {
        assert!(
            inputs != 0,
            "merge-latest input count must be greater than zero"
        );
        MergeLatest {
            inputs,
            eager_complete,
            _marker: PhantomData,
        }
    }
}
1618
1619impl<T> GraphStage for MergeLatest<T>
1620where
1621    T: Clone + Send + 'static,
1622{
1623    type Shape = FanInShape<T, Vec<T>>;
1624
1625    fn name(&self) -> &str {
1626        "MergeLatest"
1627    }
1628
1629    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1630        let inlets = (0..self.inputs)
1631            .map(|index| allocator.inlet(format!("MergeLatest.in{index}")))
1632            .collect();
1633        FanInShape::new(inlets, allocator.outlet("MergeLatest.out"))
1634    }
1635
1636    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1637        let build_snapshot = Arc::new(move |values: &[&DatumValue]| -> DatumValue {
1638            let snapshot: Vec<T> = values
1639                .iter()
1640                .map(|dv| {
1641                    dv.as_any_ref()
1642                        .downcast_ref::<T>()
1643                        .cloned()
1644                        .expect("merge-latest snapshot: wrong element type")
1645                })
1646                .collect();
1647            datum(snapshot)
1648        });
1649        // Typed snapshot: builds Vec<T> directly from &[Option<T>] without boxing.
1650        // The closure type is complex; suppress the lint for this local alias.
1651        #[allow(clippy::type_complexity)]
1652        let typed_snapshot_fn: Arc<dyn Fn(&[Option<T>]) -> Vec<T> + Send + Sync> =
1653            Arc::new(move |slots: &[Option<T>]| {
1654                slots
1655                    .iter()
1656                    .map(|s| {
1657                        s.clone()
1658                            .expect("merge-latest typed snapshot: slot is None")
1659                    })
1660                    .collect()
1661            });
1662        let typed_snapshot: Arc<StageTypedSnapshotFn> = Arc::new(typed_snapshot_fn);
1663        StageSpec::merge_latest(
1664            Arc::from(self.name()),
1665            shape.inlets(),
1666            shape.outlets(),
1667            self.inputs,
1668            self.eager_complete,
1669            build_snapshot,
1670            typed_snapshot,
1671        )
1672    }
1673
1674    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1675        struct State<T> {
1676            latest: Vec<Option<T>>,
1677            seen: usize,
1678            completed: usize,
1679            pending: VecDeque<Vec<T>>,
1680            eager_complete: bool,
1681        }
1682
1683        struct In<T: 'static> {
1684            inlet_id: PortId,
1685            inlet_index: usize,
1686            inlet: Inlet<T>,
1687            all_inlets: Vec<Inlet<T>>,
1688            outlet: Outlet<Vec<T>>,
1689            state: Arc<Mutex<State<T>>>,
1690        }
1691
1692        impl<T> InHandler for In<T>
1693        where
1694            T: Clone + Send + 'static,
1695        {
1696            fn on_push(
1697                &mut self,
1698                logic: &mut GraphStageLogic,
1699                _inlet: AnyInlet,
1700            ) -> StreamResult<()> {
1701                let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1702                    downcast_datum(value, "grab", || {
1703                        format!("inlet#{}", self.inlet_id.as_usize())
1704                    })
1705                })?;
1706                {
1707                    let mut state = self.state.lock().expect("merge-latest state poisoned");
1708                    if state.latest[self.inlet_index].is_none() {
1709                        state.seen += 1;
1710                    }
1711                    state.latest[self.inlet_index] = Some(elem);
1712                    if state.seen == state.latest.len() {
1713                        let snapshot = state
1714                            .latest
1715                            .iter()
1716                            .map(|item| item.clone().expect("merge-latest seen"))
1717                            .collect();
1718                        state.pending.push_back(snapshot);
1719                    }
1720                }
1721                let next = if logic.is_available(&self.outlet) {
1722                    self.state
1723                        .lock()
1724                        .expect("merge-latest state poisoned")
1725                        .pending
1726                        .pop_front()
1727                } else {
1728                    None
1729                };
1730                if let Some(value) = next {
1731                    logic.push(&self.outlet, value)?;
1732                }
1733                if !logic.has_been_pulled(&self.inlet) {
1734                    logic.pull(&self.inlet)?;
1735                }
1736                Ok(())
1737            }
1738
1739            fn on_upstream_finish(
1740                &mut self,
1741                logic: &mut GraphStageLogic,
1742                _inlet: AnyInlet,
1743            ) -> StreamResult<()> {
1744                let state = {
1745                    let mut state = self.state.lock().expect("merge-latest state poisoned");
1746                    state.completed += 1;
1747                    (
1748                        state.completed == self.all_inlets.len(),
1749                        state.eager_complete,
1750                        state.pending.is_empty(),
1751                    )
1752                };
1753                if state.0 || (state.1 && state.2) {
1754                    logic.complete(&self.outlet)?;
1755                }
1756                Ok(())
1757            }
1758        }
1759
1760        struct Out<T: 'static> {
1761            inlets: Vec<Inlet<T>>,
1762            outlet: Outlet<Vec<T>>,
1763            state: Arc<Mutex<State<T>>>,
1764        }
1765
1766        impl<T> OutHandler for Out<T>
1767        where
1768            T: Clone + Send + 'static,
1769        {
1770            fn on_pull(
1771                &mut self,
1772                logic: &mut GraphStageLogic,
1773                _outlet: AnyOutlet,
1774            ) -> StreamResult<()> {
1775                let next = self
1776                    .state
1777                    .lock()
1778                    .expect("merge-latest state poisoned")
1779                    .pending
1780                    .pop_front();
1781                if let Some(value) = next {
1782                    logic.push(&self.outlet, value)?;
1783                } else {
1784                    let state = self.state.lock().expect("merge-latest state poisoned");
1785                    if state.completed == self.inlets.len()
1786                        || (state.eager_complete && state.completed > 0)
1787                    {
1788                        logic.complete(&self.outlet)?;
1789                    }
1790                }
1791                for inlet in &self.inlets {
1792                    if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1793                        logic.pull(inlet)?;
1794                    }
1795                }
1796                Ok(())
1797            }
1798        }
1799
1800        let inlets = shape.inlets_vec();
1801        let outlet = shape.outlet();
1802        let state = Arc::new(Mutex::new(State {
1803            latest: vec![None; inlets.len()],
1804            seen: 0,
1805            completed: 0,
1806            pending: VecDeque::new(),
1807            eager_complete: self.eager_complete,
1808        }));
1809        let mut logic = GraphStageLogic::new(shape);
1810        for (index, inlet) in inlets.iter().cloned().enumerate() {
1811            logic
1812                .set_handler(
1813                    &inlet.clone(),
1814                    Box::new(In {
1815                        inlet_id: inlet.id(),
1816                        inlet_index: index,
1817                        inlet: inlet.clone(),
1818                        all_inlets: inlets.clone(),
1819                        outlet: outlet.clone(),
1820                        state: Arc::clone(&state),
1821                    }),
1822                )
1823                .unwrap();
1824        }
1825        logic
1826            .set_out_handler(
1827                &outlet.clone(),
1828                Box::new(Out {
1829                    inlets,
1830                    outlet: outlet.clone(),
1831                    state,
1832                }),
1833            )
1834            .unwrap();
1835        logic
1836    }
1837}
1838
/// Fan-out stage with `outputs` outlets. A user-supplied `partitioner` maps
/// each element to an outlet index; the routing itself lives in the stage
/// logic.
///
/// `eager_cancel` is forwarded to the stage spec and controls how downstream
/// cancellation is handled.
#[derive(Clone)]
pub struct Partition<T: 'static> {
    outputs: usize,
    partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
    eager_cancel: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for Partition<T> {
    /// Shows the output count and cancel mode; the partitioner is elided (`..`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Partition");
        builder.field("outputs", &self.outputs);
        builder.field("eager_cancel", &self.eager_cancel);
        builder.finish_non_exhaustive()
    }
}

impl<T: 'static> Partition<T> {
    /// Creates a partition with `eager_cancel` disabled.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new<F>(outputs: usize, partitioner: F) -> Self
    where
        F: Fn(&T) -> usize + Send + Sync + 'static,
    {
        Partition::new_with_eager_cancel(outputs, partitioner, false)
    }

    /// Creates a partition with an explicit `eager_cancel` setting.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new_with_eager_cancel<F>(outputs: usize, partitioner: F, eager_cancel: bool) -> Self
    where
        F: Fn(&T) -> usize + Send + Sync + 'static,
    {
        assert!(
            outputs != 0,
            "partition output count must be greater than zero"
        );
        let partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync> = Arc::new(partitioner);
        Partition {
            outputs,
            partitioner,
            eager_cancel,
            _marker: PhantomData,
        }
    }
}
1882
1883impl<T> GraphStage for Partition<T>
1884where
1885    T: Clone + Send + 'static,
1886{
1887    type Shape = FanOutShape<T, T>;
1888
1889    fn name(&self) -> &str {
1890        "Partition"
1891    }
1892
1893    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1894        let inlet = allocator.inlet("Partition.in");
1895        let outlets = (0..self.outputs)
1896            .map(|index| allocator.outlet(format!("Partition.out{index}")))
1897            .collect();
1898        FanOutShape::new(inlet, outlets)
1899    }
1900
1901    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1902        let partitioner_clone = Arc::clone(&self.partitioner);
1903        let partitioner = Arc::new(move |dv: &DatumValue| -> usize {
1904            let t: &T = dv
1905                .as_any_ref()
1906                .downcast_ref::<T>()
1907                .expect("partition: wrong element type");
1908            partitioner_clone(t)
1909        });
1910        let typed_partitioner_fn: Arc<dyn Fn(&T) -> usize + Send + Sync> =
1911            Arc::clone(&self.partitioner);
1912        let typed_partitioner: Arc<StageTypedPartitionFn> = Arc::new(typed_partitioner_fn);
1913        StageSpec::partition(
1914            Arc::from(self.name()),
1915            shape.inlets(),
1916            shape.outlets(),
1917            self.outputs,
1918            partitioner,
1919            typed_partitioner,
1920            self.eager_cancel,
1921        )
1922    }
1923
1924    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1925        struct State<T> {
1926            pending: Option<(usize, T)>,
1927            upstream_closed: bool,
1928            live_outlets: usize,
1929            cancelled: Vec<bool>,
1930            eager_cancel: bool,
1931        }
1932
1933        fn any_live_demand<T>(
1934            logic: &GraphStageLogic,
1935            outlets: &[Outlet<T>],
1936            cancelled: &[bool],
1937        ) -> bool
1938        where
1939            T: Clone + Send + 'static,
1940        {
1941            outlets
1942                .iter()
1943                .enumerate()
1944                .any(|(index, outlet)| !cancelled[index] && logic.is_available(outlet))
1945        }
1946
1947        struct In<T: 'static> {
1948            inlet_id: PortId,
1949            inlet: Inlet<T>,
1950            outlets: Vec<Outlet<T>>,
1951            partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
1952            state: Arc<Mutex<State<T>>>,
1953        }
1954
1955        impl<T> InHandler for In<T>
1956        where
1957            T: Clone + Send + 'static,
1958        {
1959            fn on_push(
1960                &mut self,
1961                logic: &mut GraphStageLogic,
1962                _inlet: AnyInlet,
1963            ) -> StreamResult<()> {
1964                let item: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1965                    downcast_datum(value, "grab", || {
1966                        format!("inlet#{}", self.inlet_id.as_usize())
1967                    })
1968                })?;
1969                let idx = (self.partitioner)(&item);
1970                if idx >= self.outlets.len() {
1971                    return Err(StreamError::Failed(format!(
1972                        "partitioner returned out-of-bounds index {idx} for {} outputs",
1973                        self.outlets.len()
1974                    )));
1975                }
1976                let mut pull_again = false;
1977                {
1978                    let mut state = self.state.lock().expect("partition state poisoned");
1979                    if state.cancelled[idx] {
1980                        pull_again = !state.upstream_closed
1981                            && any_live_demand(logic, &self.outlets, &state.cancelled);
1982                    } else if logic.is_available(&self.outlets[idx]) {
1983                        logic.push(&self.outlets[idx], item)?;
1984                        pull_again = !state.upstream_closed
1985                            && any_live_demand(logic, &self.outlets, &state.cancelled);
1986                    } else {
1987                        state.pending = Some((idx, item));
1988                    }
1989                }
1990                if pull_again && !logic.has_been_pulled(&self.inlet) {
1991                    logic.pull(&self.inlet)?;
1992                }
1993                Ok(())
1994            }
1995
1996            fn on_upstream_finish(
1997                &mut self,
1998                logic: &mut GraphStageLogic,
1999                _inlet: AnyInlet,
2000            ) -> StreamResult<()> {
2001                let complete_now = {
2002                    let mut state = self.state.lock().expect("partition state poisoned");
2003                    state.upstream_closed = true;
2004                    state.pending.is_none()
2005                };
2006                if complete_now {
2007                    for outlet in &self.outlets {
2008                        if !logic.is_closed(outlet) {
2009                            logic.complete(outlet)?;
2010                        }
2011                    }
2012                }
2013                Ok(())
2014            }
2015        }
2016
2017        struct Out<T: 'static> {
2018            index: usize,
2019            inlet: Inlet<T>,
2020            outlets: Vec<Outlet<T>>,
2021            state: Arc<Mutex<State<T>>>,
2022        }
2023
2024        impl<T> OutHandler for Out<T>
2025        where
2026            T: Clone + Send + 'static,
2027        {
2028            fn on_pull(
2029                &mut self,
2030                logic: &mut GraphStageLogic,
2031                _outlet: AnyOutlet,
2032            ) -> StreamResult<()> {
2033                let mut complete_now = false;
2034                let pending = {
2035                    let mut state = self.state.lock().expect("partition state poisoned");
2036                    if let Some((idx, _)) = &state.pending
2037                        && *idx == self.index
2038                    {
2039                        state.pending.take()
2040                    } else {
2041                        None
2042                    }
2043                };
2044                if let Some((_, item)) = pending {
2045                    logic.push(&self.outlets[self.index], item)?;
2046                    let state = self.state.lock().expect("partition state poisoned");
2047                    if state.upstream_closed {
2048                        complete_now = true;
2049                    } else if any_live_demand(logic, &self.outlets, &state.cancelled)
2050                        && !logic.has_been_pulled(&self.inlet)
2051                    {
2052                        logic.pull(&self.inlet)?;
2053                    }
2054                } else {
2055                    let state = self.state.lock().expect("partition state poisoned");
2056                    if state.upstream_closed {
2057                        complete_now = true;
2058                    } else if any_live_demand(logic, &self.outlets, &state.cancelled)
2059                        && !logic.has_been_pulled(&self.inlet)
2060                    {
2061                        logic.pull(&self.inlet)?;
2062                    }
2063                }
2064                if complete_now {
2065                    for outlet in &self.outlets {
2066                        if !logic.is_closed(outlet) {
2067                            logic.complete(outlet)?;
2068                        }
2069                    }
2070                }
2071                Ok(())
2072            }
2073
2074            fn on_downstream_finish(
2075                &mut self,
2076                logic: &mut GraphStageLogic,
2077                _outlet: AnyOutlet,
2078            ) -> StreamResult<()> {
2079                let (cancel_stage, clear_pending) = {
2080                    let mut state = self.state.lock().expect("partition state poisoned");
2081                    if state.cancelled[self.index] {
2082                        return Ok(());
2083                    }
2084                    state.cancelled[self.index] = true;
2085                    state.live_outlets -= 1;
2086                    let clear_pending = state
2087                        .pending
2088                        .as_ref()
2089                        .is_some_and(|(idx, _)| *idx == self.index);
2090                    let cancel_stage = state.eager_cancel || state.live_outlets == 0;
2091                    if clear_pending {
2092                        state.pending = None;
2093                    }
2094                    (cancel_stage, clear_pending)
2095                };
2096                if cancel_stage {
2097                    logic.complete_stage()?;
2098                } else if clear_pending
2099                    && !logic.has_been_pulled(&self.inlet)
2100                    && !logic.is_closed(&self.inlet)
2101                {
2102                    let state = self.state.lock().expect("partition state poisoned");
2103                    if any_live_demand(logic, &self.outlets, &state.cancelled) {
2104                        logic.pull(&self.inlet)?;
2105                    }
2106                }
2107                Ok(())
2108            }
2109        }
2110
2111        let inlet = shape.inlet();
2112        let outlets = shape.outlets_vec();
2113        let state = Arc::new(Mutex::new(State {
2114            pending: None,
2115            upstream_closed: false,
2116            live_outlets: outlets.len(),
2117            cancelled: vec![false; outlets.len()],
2118            eager_cancel: self.eager_cancel,
2119        }));
2120        let mut logic = GraphStageLogic::new(shape);
2121        logic
2122            .set_handler(
2123                &inlet,
2124                Box::new(In {
2125                    inlet_id: inlet.id(),
2126                    inlet: inlet.clone(),
2127                    outlets: outlets.clone(),
2128                    partitioner: Arc::clone(&self.partitioner),
2129                    state: Arc::clone(&state),
2130                }),
2131            )
2132            .unwrap();
2133        for (index, outlet) in outlets.iter().cloned().enumerate() {
2134            logic
2135                .set_out_handler(
2136                    &outlet,
2137                    Box::new(Out {
2138                        index,
2139                        inlet: inlet.clone(),
2140                        outlets: outlets.clone(),
2141                        state: Arc::clone(&state),
2142                    }),
2143                )
2144                .unwrap();
2145        }
2146        logic
2147    }
2148}
2149
/// Fan-out stage whose inlet carries `(A, B)` pairs and whose two outlets
/// carry the `A` and `B` components respectively.
///
/// The stage holds no runtime configuration; the marker only pins the types.
#[derive(Clone, Debug)]
pub struct Unzip<A: 'static, B: 'static> {
    _marker: PhantomData<fn() -> (A, B)>,
}

impl<A: 'static, B: 'static> Unzip<A, B> {
    /// Creates a new `Unzip` stage.
    #[must_use]
    pub fn new() -> Self {
        Unzip {
            _marker: PhantomData,
        }
    }
}

impl<A: 'static, B: 'static> Default for Unzip<A, B> {
    /// Same as [`Unzip::new`].
    fn default() -> Self {
        Unzip::new()
    }
}
2169
2170impl<A, B> GraphStage for Unzip<A, B>
2171where
2172    A: Clone + Send + 'static,
2173    B: Clone + Send + 'static,
2174{
2175    type Shape = FanOutShape2<(A, B), A, B>;
2176
2177    fn name(&self) -> &str {
2178        "Unzip"
2179    }
2180
2181    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
2182        FanOutShape2::new(
2183            allocator.inlet("Unzip.in"),
2184            allocator.outlet("Unzip.out0"),
2185            allocator.outlet("Unzip.out1"),
2186        )
2187    }
2188
2189    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2190        let split = Arc::new(|dv: DatumValue| -> (DatumValue, DatumValue) {
2191            let pair: (A, B) =
2192                downcast_datum(dv, "unzip", || "Unzip.in").expect("unzip: wrong element type");
2193            (datum(pair.0), datum(pair.1))
2194        });
2195        // Typed split: `|(a, b): (A, B)| -> (A, B)`.  Stored as an opaque
2196        // `Arc<StageTypedUnzipFn>` and down-cast at plan time.
2197        // The `Arc<dyn Fn((A, B)) -> (A, B)>` wrapper is intentionally verbose;
2198        // suppress the type_complexity lint for this one binding.
2199        #[allow(clippy::type_complexity)]
2200        let typed_split_fn: Arc<dyn Fn((A, B)) -> (A, B) + Send + Sync> =
2201            Arc::new(|pair: (A, B)| pair);
2202        let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
2203        StageSpec::unzip(
2204            Arc::from(self.name()),
2205            shape.inlets(),
2206            shape.outlets(),
2207            split,
2208            typed_split,
2209        )
2210    }
2211
2212    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
2213        UnzipWith::new(|pair: (A, B)| pair).create_logic(shape)
2214    }
2215}
2216
/// Fan-out stage that feeds each input element through a user-supplied
/// `split` function and emits the two halves on separate outlets.
#[derive(Clone)]
pub struct UnzipWith<In: 'static, Out0: 'static, Out1: 'static> {
    /// The splitting function; shared (not copied) between clones.
    split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
    _marker: PhantomData<fn(In) -> (Out0, Out1)>,
}

impl<In: 'static, Out0: 'static, Out1: 'static> fmt::Debug for UnzipWith<In, Out0, Out1> {
    /// The split closure is opaque, so only the type name is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("UnzipWith");
        builder.finish_non_exhaustive()
    }
}

impl<In: 'static, Out0: 'static, Out1: 'static> UnzipWith<In, Out0, Out1> {
    /// Creates a stage that splits every element with `split`.
    #[must_use]
    pub fn new<F>(split: F) -> Self
    where
        F: Fn(In) -> (Out0, Out1) + Send + Sync + 'static,
    {
        let split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::new(split);
        UnzipWith {
            split,
            _marker: PhantomData,
        }
    }
}
2241
2242impl<In, Out0, Out1> GraphStage for UnzipWith<In, Out0, Out1>
2243where
2244    In: Clone + Send + 'static,
2245    Out0: Clone + Send + 'static,
2246    Out1: Clone + Send + 'static,
2247{
2248    type Shape = FanOutShape2<In, Out0, Out1>;
2249
2250    fn name(&self) -> &str {
2251        "UnzipWith"
2252    }
2253
2254    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
2255        FanOutShape2::new(
2256            allocator.inlet("UnzipWith.in"),
2257            allocator.outlet("UnzipWith.out0"),
2258            allocator.outlet("UnzipWith.out1"),
2259        )
2260    }
2261
2262    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2263        let split_fn = Arc::clone(&self.split);
2264        let split = Arc::new(move |dv: DatumValue| -> (DatumValue, DatumValue) {
2265            let value: In = downcast_datum(dv, "unzip_with", || "UnzipWith.in")
2266                .expect("unzip-with: wrong element type");
2267            let (out0, out1) = split_fn(value);
2268            (datum(out0), datum(out1))
2269        });
2270        // Typed split: `Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>`.
2271        // Stored as opaque `Arc<StageTypedUnzipFn>` and down-cast at plan time.
2272        let typed_split_fn: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::clone(&self.split);
2273        let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
2274        StageSpec::unzip(
2275            Arc::from(self.name()),
2276            shape.inlets(),
2277            shape.outlets(),
2278            split,
2279            typed_split,
2280        )
2281    }
2282
2283    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
2284        struct State {
2285            left_open: bool,
2286            right_open: bool,
2287            upstream_closed: bool,
2288        }
2289
2290        struct InHandlerState<In: 'static, Out0: 'static, Out1: 'static> {
2291            inlet_id: PortId,
2292            inlet: Inlet<In>,
2293            out0: Outlet<Out0>,
2294            out1: Outlet<Out1>,
2295            split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
2296            state: Arc<Mutex<State>>,
2297        }
2298
2299        impl<In, Out0, Out1> InHandler for InHandlerState<In, Out0, Out1>
2300        where
2301            In: Clone + Send + 'static,
2302            Out0: Clone + Send + 'static,
2303            Out1: Clone + Send + 'static,
2304        {
2305            fn on_push(
2306                &mut self,
2307                logic: &mut GraphStageLogic,
2308                _inlet: AnyInlet,
2309            ) -> StreamResult<()> {
2310                let value: In = logic.grab_datum(self.inlet_id).and_then(|value| {
2311                    downcast_datum(value, "grab", || {
2312                        format!("inlet#{}", self.inlet_id.as_usize())
2313                    })
2314                })?;
2315                let (left, right) = (self.split)(value);
2316                let state = self.state.lock().expect("unzip-with state poisoned");
2317                if state.left_open {
2318                    logic.push(&self.out0, left)?;
2319                }
2320                if state.right_open {
2321                    logic.push(&self.out1, right)?;
2322                }
2323                drop(state);
2324                let state = self.state.lock().expect("unzip-with state poisoned");
2325                let left_ready = !state.left_open || logic.is_available(&self.out0);
2326                let right_ready = !state.right_open || logic.is_available(&self.out1);
2327                if (state.left_open || state.right_open)
2328                    && left_ready
2329                    && right_ready
2330                    && !logic.has_been_pulled(&self.inlet)
2331                {
2332                    logic.pull(&self.inlet)?;
2333                }
2334                Ok(())
2335            }
2336
2337            fn on_upstream_finish(
2338                &mut self,
2339                logic: &mut GraphStageLogic,
2340                _inlet: AnyInlet,
2341            ) -> StreamResult<()> {
2342                self.state
2343                    .lock()
2344                    .expect("unzip-with state poisoned")
2345                    .upstream_closed = true;
2346                if !logic.is_closed(&self.out0) {
2347                    logic.complete(&self.out0)?;
2348                }
2349                if !logic.is_closed(&self.out1) {
2350                    logic.complete(&self.out1)?;
2351                }
2352                Ok(())
2353            }
2354        }
2355
2356        struct Out<In: 'static, Out0: 'static, Out1: 'static> {
2357            is_left: bool,
2358            inlet: Inlet<In>,
2359            out0: Outlet<Out0>,
2360            out1: Outlet<Out1>,
2361            state: Arc<Mutex<State>>,
2362        }
2363
2364        impl<In, Out0, Out1> OutHandler for Out<In, Out0, Out1>
2365        where
2366            In: Clone + Send + 'static,
2367            Out0: Clone + Send + 'static,
2368            Out1: Clone + Send + 'static,
2369        {
2370            fn on_pull(
2371                &mut self,
2372                logic: &mut GraphStageLogic,
2373                _outlet: AnyOutlet,
2374            ) -> StreamResult<()> {
2375                let state = self.state.lock().expect("unzip-with state poisoned");
2376                let left_ready = !state.left_open || logic.is_available(&self.out0);
2377                let right_ready = !state.right_open || logic.is_available(&self.out1);
2378                if state.upstream_closed {
2379                    drop(state);
2380                    if !logic.is_closed(&self.out0) {
2381                        logic.complete(&self.out0)?;
2382                    }
2383                    if !logic.is_closed(&self.out1) {
2384                        logic.complete(&self.out1)?;
2385                    }
2386                } else if (state.left_open || state.right_open)
2387                    && left_ready
2388                    && right_ready
2389                    && !logic.has_been_pulled(&self.inlet)
2390                {
2391                    drop(state);
2392                    logic.pull(&self.inlet)?;
2393                }
2394                Ok(())
2395            }
2396
2397            fn on_downstream_finish(
2398                &mut self,
2399                logic: &mut GraphStageLogic,
2400                _outlet: AnyOutlet,
2401            ) -> StreamResult<()> {
2402                let mut state = self.state.lock().expect("unzip-with state poisoned");
2403                if self.is_left {
2404                    state.left_open = false;
2405                } else {
2406                    state.right_open = false;
2407                }
2408                if !state.left_open && !state.right_open {
2409                    logic.complete_stage()?;
2410                    return Ok(());
2411                }
2412                let left_ready = !state.left_open || logic.is_available(&self.out0);
2413                let right_ready = !state.right_open || logic.is_available(&self.out1);
2414                if !state.upstream_closed
2415                    && (state.left_open || state.right_open)
2416                    && left_ready
2417                    && right_ready
2418                    && !logic.has_been_pulled(&self.inlet)
2419                {
2420                    logic.pull(&self.inlet)?;
2421                }
2422                Ok(())
2423            }
2424        }
2425
2426        let inlet = shape.inlet();
2427        let out0 = shape.out0();
2428        let out1 = shape.out1();
2429        let state = Arc::new(Mutex::new(State {
2430            left_open: true,
2431            right_open: true,
2432            upstream_closed: false,
2433        }));
2434        let mut logic = GraphStageLogic::new(shape);
2435        logic
2436            .set_handler(
2437                &inlet,
2438                Box::new(InHandlerState {
2439                    inlet_id: inlet.id(),
2440                    inlet: inlet.clone(),
2441                    out0: out0.clone(),
2442                    out1: out1.clone(),
2443                    split: Arc::clone(&self.split),
2444                    state: Arc::clone(&state),
2445                }),
2446            )
2447            .unwrap();
2448        logic
2449            .set_out_handler(
2450                &out0,
2451                Box::new(Out {
2452                    is_left: true,
2453                    inlet: inlet.clone(),
2454                    out0: out0.clone(),
2455                    out1: out1.clone(),
2456                    state: Arc::clone(&state),
2457                }),
2458            )
2459            .unwrap();
2460        logic
2461            .set_out_handler(
2462                &out1.clone(),
2463                Box::new(Out {
2464                    is_left: false,
2465                    inlet: inlet.clone(),
2466                    out0: out0.clone(),
2467                    out1: out1.clone(),
2468                    state,
2469                }),
2470            )
2471            .unwrap();
2472        logic
2473    }
2474}
2475
/// Pass-through stage marking an asynchronous boundary in a graph.
///
/// Carries no runtime configuration; the marker only pins the element type.
#[derive(Clone, Debug)]
pub struct AsyncBoundary<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> AsyncBoundary<T> {
    /// Creates a new `AsyncBoundary` stage.
    #[must_use]
    pub fn new() -> Self {
        AsyncBoundary {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for AsyncBoundary<T> {
    /// Same as [`AsyncBoundary::new`].
    fn default() -> Self {
        AsyncBoundary::new()
    }
}
2495
2496impl<T> GraphStage for AsyncBoundary<T>
2497where
2498    T: Clone + Send + 'static,
2499{
2500    type Shape = FlowShape<T, T>;
2501
2502    fn name(&self) -> &str {
2503        "AsyncBoundary"
2504    }
2505
2506    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
2507        let first_id = next_port_id_block(2);
2508        FlowShape::new(
2509            Inlet::with_arc_name(first_id, async_boundary_inlet_name()),
2510            Outlet::with_arc_name(first_id.offset(1), async_boundary_outlet_name()),
2511        )
2512    }
2513
2514    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2515        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
2516    }
2517
2518    fn stage_spec_with_ports(
2519        &self,
2520        _shape: &Self::Shape,
2521        inlets: Vec<AnyInlet>,
2522        outlets: Vec<AnyOutlet>,
2523    ) -> StageSpec {
2524        StageSpec::async_boundary(async_boundary_stage_name(), inlets, outlets)
2525    }
2526}