Skip to main content

datum/graph/
junctions.rs

1use super::*;
2
/// Identity junction: a one-inlet, one-outlet stage whose element type is
/// the same (`T`) on both sides.
///
/// `T` is only tracked through a `PhantomData<fn() -> T>` marker; the stage
/// stores no values of type `T`.
#[derive(Clone, Debug)]
pub struct Identity<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Identity<T> {
    /// Creates a new identity stage.
    #[must_use]
    pub fn new() -> Self {
        Identity { _marker: PhantomData }
    }
}

impl<T: 'static> Default for Identity<T> {
    /// Same as `Identity::new()`.
    fn default() -> Self {
        Identity::new()
    }
}
22
23impl<T> GraphStage for Identity<T>
24where
25    T: Clone + Send + 'static,
26{
27    type Shape = FlowShape<T, T>;
28
29    fn name(&self) -> &str {
30        "Identity"
31    }
32
33    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
34        let first_id = next_port_id_block(2);
35        FlowShape::new(
36            Inlet::with_arc_name(first_id, identity_inlet_name()),
37            Outlet::with_arc_name(first_id.offset(1), identity_outlet_name()),
38        )
39    }
40
41    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
42        StageSpec::identity(shape.inlets(), shape.outlets())
43    }
44
45    fn stage_spec_with_ports(
46        &self,
47        _shape: &Self::Shape,
48        inlets: Vec<AnyInlet>,
49        outlets: Vec<AnyOutlet>,
50    ) -> StageSpec {
51        StageSpec::identity(inlets, outlets)
52    }
53}
54
/// One-in/one-out stage that applies a user function `In -> Out` to each
/// element.
///
/// The function lives behind an `Arc`, so clones of the stage share the same
/// closure instead of copying it.
#[derive(Clone)]
pub struct MapStage<In: 'static, Out: 'static> {
    f: Arc<dyn Fn(In) -> Out + Send + Sync>,
    _marker: PhantomData<fn(In) -> Out>,
}

impl<In: 'static, Out: 'static> fmt::Debug for MapStage<In, Out> {
    /// Prints a fixed `name` field only; the wrapped `dyn Fn` has no `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MapStage");
        builder.field("name", &"Map");
        builder.finish_non_exhaustive()
    }
}

impl<In: 'static, Out: 'static> MapStage<In, Out> {
    /// Wraps `f` as a map stage.
    #[must_use]
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(In) -> Out + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(In) -> Out + Send + Sync> = Arc::new(f);
        MapStage {
            f: shared,
            _marker: PhantomData,
        }
    }
}
81
82impl<In, Out> GraphStage for MapStage<In, Out>
83where
84    In: Clone + Send + 'static,
85    Out: Clone + Send + 'static,
86{
87    type Shape = FlowShape<In, Out>;
88
89    fn name(&self) -> &str {
90        "Map"
91    }
92
93    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
94        let first_id = next_port_id_block(2);
95        FlowShape::new(
96            Inlet::with_arc_name(first_id, map_inlet_name()),
97            Outlet::with_arc_name(first_id.offset(1), map_outlet_name()),
98        )
99    }
100
101    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
102        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
103    }
104
105    fn stage_spec_with_ports(
106        &self,
107        _shape: &Self::Shape,
108        inlets: Vec<AnyInlet>,
109        outlets: Vec<AnyOutlet>,
110    ) -> StageSpec {
111        let f = Arc::clone(&self.f);
112        let typed = Arc::new(Arc::clone(&self.f)) as Arc<StageTypedMapFn>;
113        let mapper = Arc::new(move |value: DatumValue| {
114            let value: In = downcast_datum(value, "map", || "Map.in")?;
115            Ok(datum(f(value)))
116        });
117        StageSpec::map(map_stage_name(), inlets, outlets, mapper, typed)
118    }
119}
120
/// Fan-out junction with a single inlet and `outputs` outlets.
#[derive(Clone, Debug)]
pub struct Broadcast<T: 'static> {
    outputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Broadcast<T> {
    /// Creates a broadcast junction with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        if outputs == 0 {
            panic!("broadcast output count must be greater than zero");
        }
        Broadcast {
            outputs,
            _marker: PhantomData,
        }
    }
}
140
141impl<T> GraphStage for Broadcast<T>
142where
143    T: Clone + Send + 'static,
144{
145    type Shape = FanOutShape<T, T>;
146
147    fn name(&self) -> &str {
148        "Broadcast"
149    }
150
151    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
152        let inlet = allocator.inlet_arc(broadcast_inlet_name());
153        let outlets = (0..self.outputs)
154            .map(|index| allocator.outlet(format!("Broadcast.out{index}")))
155            .collect();
156        FanOutShape::new(inlet, outlets)
157    }
158
159    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
160        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
161    }
162
163    fn stage_spec_with_ports(
164        &self,
165        _shape: &Self::Shape,
166        inlets: Vec<AnyInlet>,
167        outlets: Vec<AnyOutlet>,
168    ) -> StageSpec {
169        StageSpec::broadcast(broadcast_stage_name(), inlets, outlets)
170    }
171}
172
/// Fan-out junction with a single inlet and `outputs` outlets.
#[derive(Clone, Debug)]
pub struct Balance<T: 'static> {
    outputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Balance<T> {
    /// Creates a balance junction with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        if outputs == 0 {
            panic!("balance output count must be greater than zero");
        }
        Balance {
            outputs,
            _marker: PhantomData,
        }
    }
}
192
193impl<T> GraphStage for Balance<T>
194where
195    T: Clone + Send + 'static,
196{
197    type Shape = FanOutShape<T, T>;
198
199    fn name(&self) -> &str {
200        "Balance"
201    }
202
203    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
204        let inlet = allocator.inlet_arc(balance_inlet_name());
205        let outlets = (0..self.outputs)
206            .map(|index| allocator.outlet(format!("Balance.out{index}")))
207            .collect();
208        FanOutShape::new(inlet, outlets)
209    }
210
211    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
212        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
213    }
214
215    fn stage_spec_with_ports(
216        &self,
217        _shape: &Self::Shape,
218        inlets: Vec<AnyInlet>,
219        outlets: Vec<AnyOutlet>,
220    ) -> StageSpec {
221        StageSpec::balance(balance_stage_name(), inlets, outlets)
222    }
223}
224
/// Fan-in junction with `inputs` inlets and a single outlet.
#[derive(Clone, Debug)]
pub struct Merge<T: 'static> {
    inputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Merge<T> {
    /// Creates a merge junction with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        if inputs == 0 {
            panic!("merge input count must be greater than zero");
        }
        Merge {
            inputs,
            _marker: PhantomData,
        }
    }
}
241
242impl<T> GraphStage for Merge<T>
243where
244    T: Clone + Send + 'static,
245{
246    type Shape = FanInShape<T, T>;
247
248    fn name(&self) -> &str {
249        "Merge"
250    }
251
252    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
253        let inlets = (0..self.inputs)
254            .map(|index| allocator.inlet(format!("Merge.in{index}")))
255            .collect();
256        FanInShape::new(inlets, allocator.outlet_arc(merge_outlet_name()))
257    }
258
259    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
260        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
261    }
262
263    fn stage_spec_with_ports(
264        &self,
265        _shape: &Self::Shape,
266        inlets: Vec<AnyInlet>,
267        outlets: Vec<AnyOutlet>,
268    ) -> StageSpec {
269        StageSpec::merge(merge_stage_name(), inlets, outlets)
270    }
271}
272
/// Fan-in junction with `inputs` (at least two) inlets and a single outlet.
#[derive(Clone, Debug)]
pub struct Concat<T: 'static> {
    inputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Concat<T> {
    /// Creates a concat junction with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        if inputs < 2 {
            panic!("concat input count must be greater than one");
        }
        Concat {
            inputs,
            _marker: PhantomData,
        }
    }
}
289
290impl<T> GraphStage for Concat<T>
291where
292    T: Clone + Send + 'static,
293{
294    type Shape = FanInShape<T, T>;
295
296    fn name(&self) -> &str {
297        "Concat"
298    }
299
300    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
301        let inlets = (0..self.inputs)
302            .map(|index| allocator.inlet(format!("Concat.in{index}")))
303            .collect();
304        FanInShape::new(inlets, allocator.outlet_arc(concat_outlet_name()))
305    }
306
307    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
308        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
309    }
310
311    fn stage_spec_with_ports(
312        &self,
313        _shape: &Self::Shape,
314        inlets: Vec<AnyInlet>,
315        outlets: Vec<AnyOutlet>,
316    ) -> StageSpec {
317        StageSpec::concat(concat_stage_name(), inlets, outlets)
318    }
319}
320
/// Two-input junction (a primary and a secondary inlet) with one outlet.
///
/// `T` is tracked only through a `PhantomData<fn() -> T>` marker.
#[derive(Clone, Debug)]
pub struct OrElse<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> OrElse<T> {
    /// Creates a new or-else junction.
    #[must_use]
    pub fn new() -> Self {
        OrElse { _marker: PhantomData }
    }
}

impl<T: 'static> Default for OrElse<T> {
    /// Same as `OrElse::new()`.
    fn default() -> Self {
        OrElse::new()
    }
}
340
341impl<T> GraphStage for OrElse<T>
342where
343    T: Clone + Send + 'static,
344{
345    type Shape = FanInShape<T, T>;
346
347    fn name(&self) -> &str {
348        "OrElse"
349    }
350
351    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
352        let inlets = vec![
353            allocator.inlet_arc(or_else_primary_name()),
354            allocator.inlet_arc(or_else_secondary_name()),
355        ];
356        FanInShape::new(inlets, allocator.outlet_arc(or_else_outlet_name()))
357    }
358
359    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
360        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
361    }
362
363    fn stage_spec_with_ports(
364        &self,
365        _shape: &Self::Shape,
366        inlets: Vec<AnyInlet>,
367        outlets: Vec<AnyOutlet>,
368    ) -> StageSpec {
369        StageSpec::or_else(or_else_stage_name(), inlets, outlets)
370    }
371}
372
/// Fan-in junction over `inputs` inlets, configured with a `segment_size`
/// and an `eager_close` flag that are forwarded to `StageSpec::interleave`.
#[derive(Clone, Debug)]
pub struct Interleave<T: 'static> {
    inputs: usize,
    segment_size: usize,
    eager_close: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Interleave<T> {
    /// Creates an interleave junction with `eager_close` set to `false`.
    ///
    /// # Panics
    ///
    /// Same conditions as `new_with_eager_close`.
    #[must_use]
    pub fn new(inputs: usize, segment_size: usize) -> Self {
        let eager_close = false;
        Self::new_with_eager_close(inputs, segment_size, eager_close)
    }

    /// Creates an interleave junction with an explicit `eager_close` flag.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two, or (checked second) if
    /// `segment_size` is zero.
    #[must_use]
    pub fn new_with_eager_close(inputs: usize, segment_size: usize, eager_close: bool) -> Self {
        if inputs < 2 {
            panic!("interleave input count must be greater than one");
        }
        if segment_size == 0 {
            panic!("interleave segment size must be greater than zero");
        }
        Interleave {
            inputs,
            segment_size,
            eager_close,
            _marker: PhantomData,
        }
    }
}
405
406impl<T> GraphStage for Interleave<T>
407where
408    T: Clone + Send + 'static,
409{
410    type Shape = FanInShape<T, T>;
411
412    fn name(&self) -> &str {
413        "Interleave"
414    }
415
416    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
417        let inlets = (0..self.inputs)
418            .map(|index| allocator.inlet(format!("Interleave.in{index}")))
419            .collect();
420        FanInShape::new(inlets, allocator.outlet_arc(interleave_outlet_name()))
421    }
422
423    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
424        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
425    }
426
427    fn stage_spec_with_ports(
428        &self,
429        _shape: &Self::Shape,
430        inlets: Vec<AnyInlet>,
431        outlets: Vec<AnyOutlet>,
432    ) -> StageSpec {
433        StageSpec::interleave(
434            interleave_stage_name(),
435            inlets,
436            outlets,
437            self.segment_size,
438            self.eager_close,
439        )
440    }
441}
442
/// Fan-in junction with one preferred inlet plus `secondary_ports`
/// secondary inlets, and a single outlet.
#[derive(Clone, Debug)]
pub struct MergePreferred<T: 'static> {
    secondary_ports: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePreferred<T> {
    /// Creates a merge-preferred junction with `secondary_ports` secondary
    /// inlets (in addition to the preferred one).
    ///
    /// # Panics
    ///
    /// Panics if `secondary_ports` is zero.
    #[must_use]
    pub fn new(secondary_ports: usize) -> Self {
        if secondary_ports == 0 {
            panic!("merge-preferred secondary input count must be greater than zero");
        }
        MergePreferred {
            secondary_ports,
            _marker: PhantomData,
        }
    }
}
462
463impl<T> GraphStage for MergePreferred<T>
464where
465    T: Clone + Send + 'static,
466{
467    type Shape = MergePreferredShape<T>;
468
469    fn name(&self) -> &str {
470        "MergePreferred"
471    }
472
473    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
474        let preferred = allocator.inlet_arc(merge_preferred_preferred_name());
475        let secondary = (0..self.secondary_ports)
476            .map(|index| allocator.inlet(format!("MergePreferred.in{index}")))
477            .collect();
478        MergePreferredShape::new(
479            preferred,
480            secondary,
481            allocator.outlet_arc(merge_preferred_outlet_name()),
482        )
483    }
484
485    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
486        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
487    }
488
489    fn stage_spec_with_ports(
490        &self,
491        _shape: &Self::Shape,
492        inlets: Vec<AnyInlet>,
493        outlets: Vec<AnyOutlet>,
494    ) -> StageSpec {
495        StageSpec::merge_preferred(merge_preferred_stage_name(), inlets, outlets)
496    }
497}
498
/// Fan-in junction with one inlet per entry of `weights` and a single
/// outlet; the weights are forwarded to `StageSpec::merge_prioritized`.
#[derive(Clone, Debug)]
pub struct MergePrioritized<T: 'static> {
    weights: Vec<usize>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePrioritized<T> {
    /// Creates a prioritized merge with one inlet per weight.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty, or (checked second) if any weight is
    /// zero.
    #[must_use]
    pub fn new(weights: Vec<usize>) -> Self {
        if weights.is_empty() {
            panic!("prioritized merge must have inputs");
        }
        if weights.contains(&0) {
            panic!("prioritized merge weights must be greater than zero");
        }
        MergePrioritized {
            weights,
            _marker: PhantomData,
        }
    }
}
519
520impl<T> GraphStage for MergePrioritized<T>
521where
522    T: Clone + Send + 'static,
523{
524    type Shape = FanInShape<T, T>;
525
526    fn name(&self) -> &str {
527        "MergePrioritized"
528    }
529
530    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
531        let inlets = (0..self.weights.len())
532            .map(|index| allocator.inlet(format!("MergePrioritized.in{index}")))
533            .collect();
534        FanInShape::new(
535            inlets,
536            allocator.outlet_arc(merge_prioritized_outlet_name()),
537        )
538    }
539
540    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
541        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
542    }
543
544    fn stage_spec_with_ports(
545        &self,
546        _shape: &Self::Shape,
547        inlets: Vec<AnyInlet>,
548        outlets: Vec<AnyOutlet>,
549    ) -> StageSpec {
550        StageSpec::merge_prioritized(
551            merge_prioritized_stage_name(),
552            inlets,
553            outlets,
554            self.weights.clone(),
555        )
556    }
557}
558
/// Two-input junction pairing a `Left` and a `Right` element into a
/// `(Left, Right)` tuple on its outlet.
///
/// Both types are tracked only through a `PhantomData` marker.
#[derive(Clone, Debug)]
pub struct Zip<Left: 'static, Right: 'static> {
    _marker: PhantomData<fn() -> (Left, Right)>,
}

impl<Left: 'static, Right: 'static> Zip<Left, Right> {
    /// Creates a new zip junction.
    #[must_use]
    pub fn new() -> Self {
        Zip { _marker: PhantomData }
    }
}

impl<Left: 'static, Right: 'static> Default for Zip<Left, Right> {
    /// Same as `Zip::new()`.
    fn default() -> Self {
        Zip::new()
    }
}
578
579impl<Left, Right> GraphStage for Zip<Left, Right>
580where
581    Left: Clone + Send + 'static,
582    Right: Clone + Send + 'static,
583{
584    type Shape = ZipShape<Left, Right>;
585
586    fn name(&self) -> &str {
587        "Zip"
588    }
589
590    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
591        let first_id = next_port_id_block(3);
592        ZipShape::new(
593            Inlet::with_arc_name(first_id, zip_in0_name()),
594            Inlet::with_arc_name(first_id.offset(1), zip_in1_name()),
595            Outlet::with_arc_name(first_id.offset(2), zip_outlet_name()),
596        )
597    }
598
599    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
600        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
601    }
602
603    fn stage_spec_with_ports(
604        &self,
605        _shape: &Self::Shape,
606        inlets: Vec<AnyInlet>,
607        outlets: Vec<AnyOutlet>,
608    ) -> StageSpec {
609        let zip = Arc::new(move |left: DatumValue, right: DatumValue| {
610            let left: Left = downcast_datum(left, "zip", || "Zip.in0")?;
611            let right: Right = downcast_datum(right, "zip", || "Zip.in1")?;
612            Ok(datum((left, right)))
613        });
614        StageSpec::zip(zip_stage_name(), inlets, outlets, zip)
615    }
616}
617
/// Two-input junction that merges two inputs of `T` using `T`'s `Ord`
/// ordering (see its `GraphStage` impl).
///
/// `T` is tracked only through a `PhantomData<fn() -> T>` marker.
#[derive(Clone, Debug)]
pub struct MergeSorted<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeSorted<T> {
    /// Creates a new sorted-merge junction.
    #[must_use]
    pub fn new() -> Self {
        MergeSorted { _marker: PhantomData }
    }
}

impl<T: 'static> Default for MergeSorted<T> {
    /// Same as `MergeSorted::new()`.
    fn default() -> Self {
        MergeSorted::new()
    }
}
637
638impl<T> GraphStage for MergeSorted<T>
639where
640    T: Clone + Ord + Send + 'static,
641{
642    type Shape = FanInShape<T, T>;
643
644    fn name(&self) -> &str {
645        "MergeSorted"
646    }
647
648    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
649        let inlets = vec![
650            allocator.inlet("MergeSorted.in0"),
651            allocator.inlet("MergeSorted.in1"),
652        ];
653        FanInShape::new(inlets, allocator.outlet("MergeSorted.out"))
654    }
655
656    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
657        let compare = Arc::new(
658            move |a: &DatumValue, b: &DatumValue| -> std::cmp::Ordering {
659                let a_t: &T = a
660                    .as_any_ref()
661                    .downcast_ref::<T>()
662                    .expect("merge-sorted compare: wrong element type");
663                let b_t: &T = b
664                    .as_any_ref()
665                    .downcast_ref::<T>()
666                    .expect("merge-sorted compare: wrong element type");
667                a_t.cmp(b_t)
668            },
669        );
670        StageSpec::merge_sorted(
671            Arc::from(self.name()),
672            shape.inlets(),
673            shape.outlets(),
674            compare,
675        )
676    }
677
678    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
679        struct State<T> {
680            left: VecDeque<T>,
681            right: VecDeque<T>,
682            left_closed: bool,
683            right_closed: bool,
684            pending: VecDeque<T>,
685        }
686
687        impl<T> Default for State<T> {
688            fn default() -> Self {
689                Self {
690                    left: VecDeque::new(),
691                    right: VecDeque::new(),
692                    left_closed: false,
693                    right_closed: false,
694                    pending: VecDeque::new(),
695                }
696            }
697        }
698
699        fn maybe_queue_output<T>(state: &mut State<T>) -> bool
700        where
701            T: Clone + Ord,
702        {
703            let next = match (state.left.front(), state.right.front()) {
704                (Some(left), Some(right)) => {
705                    if left <= right {
706                        state.left.pop_front()
707                    } else {
708                        state.right.pop_front()
709                    }
710                }
711                (Some(_), None) if state.right_closed => state.left.pop_front(),
712                (None, Some(_)) if state.left_closed => state.right.pop_front(),
713                _ => None,
714            };
715            if let Some(value) = next {
716                state.pending.push_back(value);
717                true
718            } else {
719                false
720            }
721        }
722
723        fn maybe_complete<T>(
724            logic: &mut GraphStageLogic,
725            outlet: &Outlet<T>,
726            state: &State<T>,
727        ) -> StreamResult<()>
728        where
729            T: Clone + Send + 'static,
730        {
731            if state.left_closed
732                && state.right_closed
733                && state.left.is_empty()
734                && state.right.is_empty()
735                && state.pending.is_empty()
736                && !logic.is_closed(outlet)
737            {
738                logic.complete(outlet)?;
739            }
740            Ok(())
741        }
742
743        fn maybe_pull<T>(
744            logic: &mut GraphStageLogic,
745            left: &Inlet<T>,
746            right: &Inlet<T>,
747            state: &State<T>,
748        ) -> StreamResult<()>
749        where
750            T: Clone + Send + 'static,
751        {
752            if state.left.is_empty() && !state.left_closed && !logic.has_been_pulled(left) {
753                logic.pull(left)?;
754            }
755            if state.right.is_empty() && !state.right_closed && !logic.has_been_pulled(right) {
756                logic.pull(right)?;
757            }
758            Ok(())
759        }
760
761        fn maybe_drain<T>(
762            logic: &mut GraphStageLogic,
763            outlet: &Outlet<T>,
764            state: &Arc<Mutex<State<T>>>,
765        ) -> StreamResult<()>
766        where
767            T: Clone + Ord + Send + 'static,
768        {
769            let next = if logic.is_available(outlet) {
770                state
771                    .lock()
772                    .expect("merge-sorted state poisoned")
773                    .pending
774                    .pop_front()
775            } else {
776                None
777            };
778            if let Some(value) = next {
779                logic.push(outlet, value)?;
780            }
781            Ok(())
782        }
783
784        struct In<T: 'static> {
785            inlet_id: PortId,
786            left: Inlet<T>,
787            right: Inlet<T>,
788            outlet: Outlet<T>,
789            state: Arc<Mutex<State<T>>>,
790        }
791
792        impl<T> InHandler for In<T>
793        where
794            T: Clone + Ord + Send + 'static,
795        {
796            fn on_push(
797                &mut self,
798                logic: &mut GraphStageLogic,
799                _inlet: AnyInlet,
800            ) -> StreamResult<()> {
801                let value: T = logic.grab_datum(self.inlet_id).and_then(|value| {
802                    downcast_datum(value, "grab", || {
803                        format!("inlet#{}", self.inlet_id.as_usize())
804                    })
805                })?;
806                {
807                    let mut state = self.state.lock().expect("merge-sorted state poisoned");
808                    if self.inlet_id == self.left.id() {
809                        state.left.push_back(value);
810                    } else {
811                        state.right.push_back(value);
812                    }
813                    while maybe_queue_output(&mut state) {}
814                }
815                maybe_drain(logic, &self.outlet, &self.state)?;
816                let state = self.state.lock().expect("merge-sorted state poisoned");
817                maybe_complete(logic, &self.outlet, &state)?;
818                maybe_pull(logic, &self.left, &self.right, &state)
819            }
820
821            fn on_upstream_finish(
822                &mut self,
823                logic: &mut GraphStageLogic,
824                _inlet: AnyInlet,
825            ) -> StreamResult<()> {
826                {
827                    let mut state = self.state.lock().expect("merge-sorted state poisoned");
828                    if self.inlet_id == self.left.id() {
829                        state.left_closed = true;
830                    } else {
831                        state.right_closed = true;
832                    }
833                    while maybe_queue_output(&mut state) {}
834                }
835                maybe_drain(logic, &self.outlet, &self.state)?;
836                let state = self.state.lock().expect("merge-sorted state poisoned");
837                maybe_complete(logic, &self.outlet, &state)?;
838                maybe_pull(logic, &self.left, &self.right, &state)
839            }
840        }
841
842        struct Out<T: 'static> {
843            left: Inlet<T>,
844            right: Inlet<T>,
845            outlet: Outlet<T>,
846            state: Arc<Mutex<State<T>>>,
847        }
848
849        impl<T> OutHandler for Out<T>
850        where
851            T: Clone + Ord + Send + 'static,
852        {
853            fn on_pull(
854                &mut self,
855                logic: &mut GraphStageLogic,
856                _outlet: AnyOutlet,
857            ) -> StreamResult<()> {
858                maybe_drain(logic, &self.outlet, &self.state)?;
859                let state = self.state.lock().expect("merge-sorted state poisoned");
860                maybe_complete(logic, &self.outlet, &state)?;
861                maybe_pull(logic, &self.left, &self.right, &state)
862            }
863        }
864
865        let state = Arc::new(Mutex::new(State::<T>::default()));
866        let left = shape.inlet(0).expect("merge-sorted left inlet");
867        let right = shape.inlet(1).expect("merge-sorted right inlet");
868        let outlet = shape.outlet();
869        let mut logic = GraphStageLogic::new(shape);
870        logic
871            .set_handler(
872                &left,
873                Box::new(In {
874                    inlet_id: left.id(),
875                    left: left.clone(),
876                    right: right.clone(),
877                    outlet: outlet.clone(),
878                    state: Arc::clone(&state),
879                }),
880            )
881            .unwrap();
882        logic
883            .set_handler(
884                &right,
885                Box::new(In {
886                    inlet_id: right.id(),
887                    left: left.clone(),
888                    right: right.clone(),
889                    outlet: outlet.clone(),
890                    state: Arc::clone(&state),
891                }),
892            )
893            .unwrap();
894        logic
895            .set_out_handler(
896                &outlet.clone(),
897                Box::new(Out {
898                    left,
899                    right,
900                    outlet: outlet.clone(),
901                    state,
902                }),
903            )
904            .unwrap();
905        logic
906    }
907}
908
/// Fan-in junction over `inputs` inlets. Each element's sequence number is
/// obtained through the user-supplied `extract_sequence` function.
///
/// The extractor is held in an `Arc`, so clones of the stage share it.
#[derive(Clone)]
pub struct MergeSequence<T: 'static> {
    inputs: usize,
    extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for MergeSequence<T> {
    /// Shows the input count only; the extractor closure has no `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MergeSequence");
        builder.field("inputs", &self.inputs);
        builder.finish_non_exhaustive()
    }
}

impl<T: 'static> MergeSequence<T> {
    /// Creates a sequence-ordered merge over `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new<F>(inputs: usize, extract_sequence: F) -> Self
    where
        F: Fn(&T) -> u64 + Send + Sync + 'static,
    {
        if inputs < 2 {
            panic!("merge sequence input count must be greater than one");
        }
        let shared: Arc<dyn Fn(&T) -> u64 + Send + Sync> = Arc::new(extract_sequence);
        MergeSequence {
            inputs,
            extract_sequence: shared,
            _marker: PhantomData,
        }
    }
}
941
942impl<T> GraphStage for MergeSequence<T>
943where
944    T: Clone + Send + 'static,
945{
946    type Shape = FanInShape<T, T>;
947
948    fn name(&self) -> &str {
949        "MergeSequence"
950    }
951
952    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
953        let inlets = (0..self.inputs)
954            .map(|index| allocator.inlet(format!("MergeSequence.in{index}")))
955            .collect();
956        FanInShape::new(inlets, allocator.outlet("MergeSequence.out"))
957    }
958
959    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
960        let extract = Arc::clone(&self.extract_sequence);
961        let extract_sequence = Arc::new(move |dv: &DatumValue| -> u64 {
962            let t: &T = dv
963                .as_any_ref()
964                .downcast_ref::<T>()
965                .expect("merge-sequence extract: wrong element type");
966            extract(t)
967        });
968        // Typed extractor: stored as `Arc<dyn Any + Send + Sync>` and
969        // down-cast at plan time to `Arc<dyn Fn(&T) -> u64 + Send + Sync>`.
970        let typed_extract_fn = Arc::clone(&self.extract_sequence);
971        let typed_extract: Arc<dyn Fn(&T) -> u64 + Send + Sync> = typed_extract_fn;
972        let typed_extract: Arc<StageTypedSequenceFn> = Arc::new(typed_extract);
973        StageSpec::merge_sequence(
974            Arc::from(self.name()),
975            shape.inlets(),
976            shape.outlets(),
977            self.inputs,
978            extract_sequence,
979            typed_extract,
980        )
981    }
982
983    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
984        #[derive(Clone)]
985        struct Pending<T> {
986            sequence: u64,
987            elem: T,
988        }
989
990        struct State<T> {
991            next_sequence: u64,
992            pending: Vec<Pending<T>>,
993            completed: usize,
994            pending_output: VecDeque<T>,
995        }
996
997        fn try_emit_pending<T>(state: &mut State<T>) -> StreamResult<()>
998        where
999            T: Clone + Send + 'static,
1000        {
1001            while let Some(index) = state
1002                .pending
1003                .iter()
1004                .position(|item| item.sequence == state.next_sequence)
1005            {
1006                let item = state.pending.remove(index);
1007                if state
1008                    .pending
1009                    .iter()
1010                    .any(|other| other.sequence == state.next_sequence)
1011                {
1012                    return Err(StreamError::Failed(format!(
1013                        "duplicate sequence {} on merge sequence",
1014                        state.next_sequence
1015                    )));
1016                }
1017                state.pending_output.push_back(item.elem);
1018                state.next_sequence += 1;
1019            }
1020            Ok(())
1021        }
1022
1023        struct In<T: 'static> {
1024            inlet_id: PortId,
1025            inlet_index: usize,
1026            inlet: Inlet<T>,
1027            all_inlets: Vec<Inlet<T>>,
1028            outlet: Outlet<T>,
1029            extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
1030            state: Arc<Mutex<State<T>>>,
1031        }
1032
1033        impl<T> InHandler for In<T>
1034        where
1035            T: Clone + Send + 'static,
1036        {
1037            fn on_push(
1038                &mut self,
1039                logic: &mut GraphStageLogic,
1040                _inlet: AnyInlet,
1041            ) -> StreamResult<()> {
1042                let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1043                    downcast_datum(value, "grab", || {
1044                        format!("inlet#{}", self.inlet_id.as_usize())
1045                    })
1046                })?;
1047                {
1048                    let mut state = self.state.lock().expect("merge-sequence state poisoned");
1049                    let sequence = (self.extract_sequence)(&elem);
1050                    if sequence < state.next_sequence {
1051                        return Err(StreamError::Failed(format!(
1052                            "sequence regression from {} to {} on port {}",
1053                            state.next_sequence, sequence, self.inlet_index
1054                        )));
1055                    }
1056                    state.pending.push(Pending { sequence, elem });
1057                    try_emit_pending(&mut state)?;
1058                }
1059                let next = if logic.is_available(&self.outlet) {
1060                    self.state
1061                        .lock()
1062                        .expect("merge-sequence state poisoned")
1063                        .pending_output
1064                        .pop_front()
1065                } else {
1066                    None
1067                };
1068                if let Some(value) = next {
1069                    logic.push(&self.outlet, value)?;
1070                }
1071                let state = self.state.lock().expect("merge-sequence state poisoned");
1072                if state.completed == self.all_inlets.len()
1073                    && state.pending.is_empty()
1074                    && state.pending_output.is_empty()
1075                {
1076                    logic.complete(&self.outlet)?;
1077                } else if logic.is_available(&self.outlet)
1078                    && state.pending_output.is_empty()
1079                    && state.pending.len() + state.completed == self.all_inlets.len()
1080                {
1081                    return Err(StreamError::Failed(format!(
1082                        "expected sequence {}, but all input ports have pushed or are complete",
1083                        state.next_sequence
1084                    )));
1085                }
1086                if !logic.has_been_pulled(&self.inlet) {
1087                    logic.pull(&self.inlet)?;
1088                }
1089                Ok(())
1090            }
1091
1092            fn on_upstream_finish(
1093                &mut self,
1094                logic: &mut GraphStageLogic,
1095                _inlet: AnyInlet,
1096            ) -> StreamResult<()> {
1097                {
1098                    let mut state = self.state.lock().expect("merge-sequence state poisoned");
1099                    state.completed += 1;
1100                }
1101                let state = self.state.lock().expect("merge-sequence state poisoned");
1102                if state.completed == self.all_inlets.len()
1103                    && state.pending.is_empty()
1104                    && state.pending_output.is_empty()
1105                {
1106                    logic.complete(&self.outlet)?;
1107                } else if logic.is_available(&self.outlet)
1108                    && state.pending_output.is_empty()
1109                    && state.pending.len() + state.completed == self.all_inlets.len()
1110                {
1111                    return Err(StreamError::Failed(format!(
1112                        "expected sequence {}, but all input ports have pushed or are complete",
1113                        state.next_sequence
1114                    )));
1115                }
1116                Ok(())
1117            }
1118        }
1119
1120        struct Out<T: 'static> {
1121            inlets: Vec<Inlet<T>>,
1122            outlet: Outlet<T>,
1123            state: Arc<Mutex<State<T>>>,
1124        }
1125
1126        impl<T> OutHandler for Out<T>
1127        where
1128            T: Clone + Send + 'static,
1129        {
1130            fn on_pull(
1131                &mut self,
1132                logic: &mut GraphStageLogic,
1133                _outlet: AnyOutlet,
1134            ) -> StreamResult<()> {
1135                let next = self
1136                    .state
1137                    .lock()
1138                    .expect("merge-sequence state poisoned")
1139                    .pending_output
1140                    .pop_front();
1141                if let Some(value) = next {
1142                    logic.push(&self.outlet, value)?;
1143                } else {
1144                    let state = self.state.lock().expect("merge-sequence state poisoned");
1145                    if state.completed == self.inlets.len() && state.pending.is_empty() {
1146                        logic.complete(&self.outlet)?;
1147                    } else if state.pending.len() + state.completed == self.inlets.len() {
1148                        return Err(StreamError::Failed(format!(
1149                            "expected sequence {}, but all input ports have pushed or are complete",
1150                            state.next_sequence
1151                        )));
1152                    }
1153                }
1154                for inlet in &self.inlets {
1155                    if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1156                        logic.pull(inlet)?;
1157                    }
1158                }
1159                Ok(())
1160            }
1161        }
1162
1163        let inlets = shape.inlets_vec();
1164        let outlet = shape.outlet();
1165        let state = Arc::new(Mutex::new(State {
1166            next_sequence: 0,
1167            pending: Vec::new(),
1168            completed: 0,
1169            pending_output: VecDeque::new(),
1170        }));
1171        let mut logic = GraphStageLogic::new(shape);
1172        for (index, inlet) in inlets.iter().cloned().enumerate() {
1173            logic
1174                .set_handler(
1175                    &inlet.clone(),
1176                    Box::new(In {
1177                        inlet_id: inlet.id(),
1178                        inlet_index: index,
1179                        inlet: inlet.clone(),
1180                        all_inlets: inlets.clone(),
1181                        outlet: outlet.clone(),
1182                        extract_sequence: Arc::clone(&self.extract_sequence),
1183                        state: Arc::clone(&state),
1184                    }),
1185                )
1186                .unwrap();
1187        }
1188        logic
1189            .set_out_handler(
1190                &outlet.clone(),
1191                Box::new(Out {
1192                    inlets,
1193                    outlet: outlet.clone(),
1194                    state,
1195                }),
1196            )
1197            .unwrap();
1198        logic
1199    }
1200}
1201
/// Fan-in junction that, once every input has produced at least one element,
/// emits a snapshot of the latest element from each input whenever any input
/// pushes.
///
/// `eager_complete` selects whether the stage may finish after the first
/// input finishes (`true`) or only after every input has finished (`false`).
#[derive(Clone, Debug)]
pub struct MergeLatest<T: 'static> {
    inputs: usize,
    eager_complete: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeLatest<T> {
    /// Creates a merge-latest junction with `inputs` input ports.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize, eager_complete: bool) -> Self {
        if inputs == 0 {
            panic!("merge-latest input count must be greater than zero");
        }
        Self {
            _marker: PhantomData,
            eager_complete,
            inputs,
        }
    }
}
1223
1224impl<T> GraphStage for MergeLatest<T>
1225where
1226    T: Clone + Send + 'static,
1227{
1228    type Shape = FanInShape<T, Vec<T>>;
1229
1230    fn name(&self) -> &str {
1231        "MergeLatest"
1232    }
1233
1234    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1235        let inlets = (0..self.inputs)
1236            .map(|index| allocator.inlet(format!("MergeLatest.in{index}")))
1237            .collect();
1238        FanInShape::new(inlets, allocator.outlet("MergeLatest.out"))
1239    }
1240
1241    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1242        let build_snapshot = Arc::new(move |values: &[&DatumValue]| -> DatumValue {
1243            let snapshot: Vec<T> = values
1244                .iter()
1245                .map(|dv| {
1246                    dv.as_any_ref()
1247                        .downcast_ref::<T>()
1248                        .cloned()
1249                        .expect("merge-latest snapshot: wrong element type")
1250                })
1251                .collect();
1252            datum(snapshot)
1253        });
1254        // Typed snapshot: builds Vec<T> directly from &[Option<T>] without boxing.
1255        // The closure type is complex; suppress the lint for this local alias.
1256        #[allow(clippy::type_complexity)]
1257        let typed_snapshot_fn: Arc<dyn Fn(&[Option<T>]) -> Vec<T> + Send + Sync> =
1258            Arc::new(move |slots: &[Option<T>]| {
1259                slots
1260                    .iter()
1261                    .map(|s| {
1262                        s.clone()
1263                            .expect("merge-latest typed snapshot: slot is None")
1264                    })
1265                    .collect()
1266            });
1267        let typed_snapshot: Arc<StageTypedSnapshotFn> = Arc::new(typed_snapshot_fn);
1268        StageSpec::merge_latest(
1269            Arc::from(self.name()),
1270            shape.inlets(),
1271            shape.outlets(),
1272            self.inputs,
1273            self.eager_complete,
1274            build_snapshot,
1275            typed_snapshot,
1276        )
1277    }
1278
1279    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1280        struct State<T> {
1281            latest: Vec<Option<T>>,
1282            seen: usize,
1283            completed: usize,
1284            pending: VecDeque<Vec<T>>,
1285            eager_complete: bool,
1286        }
1287
1288        struct In<T: 'static> {
1289            inlet_id: PortId,
1290            inlet_index: usize,
1291            inlet: Inlet<T>,
1292            all_inlets: Vec<Inlet<T>>,
1293            outlet: Outlet<Vec<T>>,
1294            state: Arc<Mutex<State<T>>>,
1295        }
1296
1297        impl<T> InHandler for In<T>
1298        where
1299            T: Clone + Send + 'static,
1300        {
1301            fn on_push(
1302                &mut self,
1303                logic: &mut GraphStageLogic,
1304                _inlet: AnyInlet,
1305            ) -> StreamResult<()> {
1306                let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1307                    downcast_datum(value, "grab", || {
1308                        format!("inlet#{}", self.inlet_id.as_usize())
1309                    })
1310                })?;
1311                {
1312                    let mut state = self.state.lock().expect("merge-latest state poisoned");
1313                    if state.latest[self.inlet_index].is_none() {
1314                        state.seen += 1;
1315                    }
1316                    state.latest[self.inlet_index] = Some(elem);
1317                    if state.seen == state.latest.len() {
1318                        let snapshot = state
1319                            .latest
1320                            .iter()
1321                            .map(|item| item.clone().expect("merge-latest seen"))
1322                            .collect();
1323                        state.pending.push_back(snapshot);
1324                    }
1325                }
1326                let next = if logic.is_available(&self.outlet) {
1327                    self.state
1328                        .lock()
1329                        .expect("merge-latest state poisoned")
1330                        .pending
1331                        .pop_front()
1332                } else {
1333                    None
1334                };
1335                if let Some(value) = next {
1336                    logic.push(&self.outlet, value)?;
1337                }
1338                if !logic.has_been_pulled(&self.inlet) {
1339                    logic.pull(&self.inlet)?;
1340                }
1341                Ok(())
1342            }
1343
1344            fn on_upstream_finish(
1345                &mut self,
1346                logic: &mut GraphStageLogic,
1347                _inlet: AnyInlet,
1348            ) -> StreamResult<()> {
1349                let state = {
1350                    let mut state = self.state.lock().expect("merge-latest state poisoned");
1351                    state.completed += 1;
1352                    (
1353                        state.completed == self.all_inlets.len(),
1354                        state.eager_complete,
1355                        state.pending.is_empty(),
1356                    )
1357                };
1358                if state.0 || (state.1 && state.2) {
1359                    logic.complete(&self.outlet)?;
1360                }
1361                Ok(())
1362            }
1363        }
1364
1365        struct Out<T: 'static> {
1366            inlets: Vec<Inlet<T>>,
1367            outlet: Outlet<Vec<T>>,
1368            state: Arc<Mutex<State<T>>>,
1369        }
1370
1371        impl<T> OutHandler for Out<T>
1372        where
1373            T: Clone + Send + 'static,
1374        {
1375            fn on_pull(
1376                &mut self,
1377                logic: &mut GraphStageLogic,
1378                _outlet: AnyOutlet,
1379            ) -> StreamResult<()> {
1380                let next = self
1381                    .state
1382                    .lock()
1383                    .expect("merge-latest state poisoned")
1384                    .pending
1385                    .pop_front();
1386                if let Some(value) = next {
1387                    logic.push(&self.outlet, value)?;
1388                } else {
1389                    let state = self.state.lock().expect("merge-latest state poisoned");
1390                    if state.completed == self.inlets.len()
1391                        || (state.eager_complete && state.completed > 0)
1392                    {
1393                        logic.complete(&self.outlet)?;
1394                    }
1395                }
1396                for inlet in &self.inlets {
1397                    if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1398                        logic.pull(inlet)?;
1399                    }
1400                }
1401                Ok(())
1402            }
1403        }
1404
1405        let inlets = shape.inlets_vec();
1406        let outlet = shape.outlet();
1407        let state = Arc::new(Mutex::new(State {
1408            latest: vec![None; inlets.len()],
1409            seen: 0,
1410            completed: 0,
1411            pending: VecDeque::new(),
1412            eager_complete: self.eager_complete,
1413        }));
1414        let mut logic = GraphStageLogic::new(shape);
1415        for (index, inlet) in inlets.iter().cloned().enumerate() {
1416            logic
1417                .set_handler(
1418                    &inlet.clone(),
1419                    Box::new(In {
1420                        inlet_id: inlet.id(),
1421                        inlet_index: index,
1422                        inlet: inlet.clone(),
1423                        all_inlets: inlets.clone(),
1424                        outlet: outlet.clone(),
1425                        state: Arc::clone(&state),
1426                    }),
1427                )
1428                .unwrap();
1429        }
1430        logic
1431            .set_out_handler(
1432                &outlet.clone(),
1433                Box::new(Out {
1434                    inlets,
1435                    outlet: outlet.clone(),
1436                    state,
1437                }),
1438            )
1439            .unwrap();
1440        logic
1441    }
1442}
1443
/// Fan-out junction that routes each element to exactly one output, chosen
/// by the user-supplied `partitioner`.
///
/// The partitioner is expected to return an index in `0..outputs`.
/// `eager_cancel` decides whether cancelling a single output cancels the
/// whole stage.
#[derive(Clone)]
pub struct Partition<T: 'static> {
    outputs: usize,
    partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
    eager_cancel: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for Partition<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The partitioner closure has no meaningful Debug form, hence `..`.
        let mut builder = f.debug_struct("Partition");
        builder.field("outputs", &self.outputs);
        builder.field("eager_cancel", &self.eager_cancel);
        builder.finish_non_exhaustive()
    }
}

impl<T: 'static> Partition<T> {
    /// Creates a partition with `outputs` ports and eager cancellation
    /// disabled.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new<F>(outputs: usize, partitioner: F) -> Self
    where
        F: Fn(&T) -> usize + Send + Sync + 'static,
    {
        let eager_cancel = false;
        Self::new_with_eager_cancel(outputs, partitioner, eager_cancel)
    }

    /// Creates a partition with `outputs` ports and an explicit
    /// `eager_cancel` policy.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new_with_eager_cancel<F>(outputs: usize, partitioner: F, eager_cancel: bool) -> Self
    where
        F: Fn(&T) -> usize + Send + Sync + 'static,
    {
        if outputs == 0 {
            panic!("partition output count must be greater than zero");
        }
        let partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync> = Arc::new(partitioner);
        Self {
            outputs,
            partitioner,
            eager_cancel,
            _marker: PhantomData,
        }
    }
}
1487
1488impl<T> GraphStage for Partition<T>
1489where
1490    T: Clone + Send + 'static,
1491{
1492    type Shape = FanOutShape<T, T>;
1493
1494    fn name(&self) -> &str {
1495        "Partition"
1496    }
1497
1498    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1499        let inlet = allocator.inlet("Partition.in");
1500        let outlets = (0..self.outputs)
1501            .map(|index| allocator.outlet(format!("Partition.out{index}")))
1502            .collect();
1503        FanOutShape::new(inlet, outlets)
1504    }
1505
1506    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1507        let partitioner_clone = Arc::clone(&self.partitioner);
1508        let partitioner = Arc::new(move |dv: &DatumValue| -> usize {
1509            let t: &T = dv
1510                .as_any_ref()
1511                .downcast_ref::<T>()
1512                .expect("partition: wrong element type");
1513            partitioner_clone(t)
1514        });
1515        StageSpec::partition(
1516            Arc::from(self.name()),
1517            shape.inlets(),
1518            shape.outlets(),
1519            self.outputs,
1520            partitioner,
1521            self.eager_cancel,
1522        )
1523    }
1524
1525    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1526        struct State<T> {
1527            pending: Option<(usize, T)>,
1528            upstream_closed: bool,
1529            live_outlets: usize,
1530            cancelled: Vec<bool>,
1531            eager_cancel: bool,
1532        }
1533
1534        fn any_live_demand<T>(
1535            logic: &GraphStageLogic,
1536            outlets: &[Outlet<T>],
1537            cancelled: &[bool],
1538        ) -> bool
1539        where
1540            T: Clone + Send + 'static,
1541        {
1542            outlets
1543                .iter()
1544                .enumerate()
1545                .any(|(index, outlet)| !cancelled[index] && logic.is_available(outlet))
1546        }
1547
1548        struct In<T: 'static> {
1549            inlet_id: PortId,
1550            inlet: Inlet<T>,
1551            outlets: Vec<Outlet<T>>,
1552            partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
1553            state: Arc<Mutex<State<T>>>,
1554        }
1555
1556        impl<T> InHandler for In<T>
1557        where
1558            T: Clone + Send + 'static,
1559        {
1560            fn on_push(
1561                &mut self,
1562                logic: &mut GraphStageLogic,
1563                _inlet: AnyInlet,
1564            ) -> StreamResult<()> {
1565                let item: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1566                    downcast_datum(value, "grab", || {
1567                        format!("inlet#{}", self.inlet_id.as_usize())
1568                    })
1569                })?;
1570                let idx = (self.partitioner)(&item);
1571                if idx >= self.outlets.len() {
1572                    return Err(StreamError::Failed(format!(
1573                        "partitioner returned out-of-bounds index {idx} for {} outputs",
1574                        self.outlets.len()
1575                    )));
1576                }
1577                let mut pull_again = false;
1578                {
1579                    let mut state = self.state.lock().expect("partition state poisoned");
1580                    if state.cancelled[idx] {
1581                        pull_again = !state.upstream_closed
1582                            && any_live_demand(logic, &self.outlets, &state.cancelled);
1583                    } else if logic.is_available(&self.outlets[idx]) {
1584                        logic.push(&self.outlets[idx], item)?;
1585                        pull_again = !state.upstream_closed
1586                            && any_live_demand(logic, &self.outlets, &state.cancelled);
1587                    } else {
1588                        state.pending = Some((idx, item));
1589                    }
1590                }
1591                if pull_again && !logic.has_been_pulled(&self.inlet) {
1592                    logic.pull(&self.inlet)?;
1593                }
1594                Ok(())
1595            }
1596
1597            fn on_upstream_finish(
1598                &mut self,
1599                logic: &mut GraphStageLogic,
1600                _inlet: AnyInlet,
1601            ) -> StreamResult<()> {
1602                let complete_now = {
1603                    let mut state = self.state.lock().expect("partition state poisoned");
1604                    state.upstream_closed = true;
1605                    state.pending.is_none()
1606                };
1607                if complete_now {
1608                    for outlet in &self.outlets {
1609                        if !logic.is_closed(outlet) {
1610                            logic.complete(outlet)?;
1611                        }
1612                    }
1613                }
1614                Ok(())
1615            }
1616        }
1617
1618        struct Out<T: 'static> {
1619            index: usize,
1620            inlet: Inlet<T>,
1621            outlets: Vec<Outlet<T>>,
1622            state: Arc<Mutex<State<T>>>,
1623        }
1624
1625        impl<T> OutHandler for Out<T>
1626        where
1627            T: Clone + Send + 'static,
1628        {
1629            fn on_pull(
1630                &mut self,
1631                logic: &mut GraphStageLogic,
1632                _outlet: AnyOutlet,
1633            ) -> StreamResult<()> {
1634                let mut complete_now = false;
1635                let pending = {
1636                    let mut state = self.state.lock().expect("partition state poisoned");
1637                    if let Some((idx, _)) = &state.pending
1638                        && *idx == self.index
1639                    {
1640                        state.pending.take()
1641                    } else {
1642                        None
1643                    }
1644                };
1645                if let Some((_, item)) = pending {
1646                    logic.push(&self.outlets[self.index], item)?;
1647                    let state = self.state.lock().expect("partition state poisoned");
1648                    if state.upstream_closed {
1649                        complete_now = true;
1650                    } else if any_live_demand(logic, &self.outlets, &state.cancelled)
1651                        && !logic.has_been_pulled(&self.inlet)
1652                    {
1653                        logic.pull(&self.inlet)?;
1654                    }
1655                } else {
1656                    let state = self.state.lock().expect("partition state poisoned");
1657                    if state.upstream_closed {
1658                        complete_now = true;
1659                    } else if any_live_demand(logic, &self.outlets, &state.cancelled)
1660                        && !logic.has_been_pulled(&self.inlet)
1661                    {
1662                        logic.pull(&self.inlet)?;
1663                    }
1664                }
1665                if complete_now {
1666                    for outlet in &self.outlets {
1667                        if !logic.is_closed(outlet) {
1668                            logic.complete(outlet)?;
1669                        }
1670                    }
1671                }
1672                Ok(())
1673            }
1674
1675            fn on_downstream_finish(
1676                &mut self,
1677                logic: &mut GraphStageLogic,
1678                _outlet: AnyOutlet,
1679            ) -> StreamResult<()> {
1680                let (cancel_stage, clear_pending) = {
1681                    let mut state = self.state.lock().expect("partition state poisoned");
1682                    if state.cancelled[self.index] {
1683                        return Ok(());
1684                    }
1685                    state.cancelled[self.index] = true;
1686                    state.live_outlets -= 1;
1687                    let clear_pending = state
1688                        .pending
1689                        .as_ref()
1690                        .is_some_and(|(idx, _)| *idx == self.index);
1691                    let cancel_stage = state.eager_cancel || state.live_outlets == 0;
1692                    if clear_pending {
1693                        state.pending = None;
1694                    }
1695                    (cancel_stage, clear_pending)
1696                };
1697                if cancel_stage {
1698                    logic.complete_stage()?;
1699                } else if clear_pending
1700                    && !logic.has_been_pulled(&self.inlet)
1701                    && !logic.is_closed(&self.inlet)
1702                {
1703                    let state = self.state.lock().expect("partition state poisoned");
1704                    if any_live_demand(logic, &self.outlets, &state.cancelled) {
1705                        logic.pull(&self.inlet)?;
1706                    }
1707                }
1708                Ok(())
1709            }
1710        }
1711
1712        let inlet = shape.inlet();
1713        let outlets = shape.outlets_vec();
1714        let state = Arc::new(Mutex::new(State {
1715            pending: None,
1716            upstream_closed: false,
1717            live_outlets: outlets.len(),
1718            cancelled: vec![false; outlets.len()],
1719            eager_cancel: self.eager_cancel,
1720        }));
1721        let mut logic = GraphStageLogic::new(shape);
1722        logic
1723            .set_handler(
1724                &inlet,
1725                Box::new(In {
1726                    inlet_id: inlet.id(),
1727                    inlet: inlet.clone(),
1728                    outlets: outlets.clone(),
1729                    partitioner: Arc::clone(&self.partitioner),
1730                    state: Arc::clone(&state),
1731                }),
1732            )
1733            .unwrap();
1734        for (index, outlet) in outlets.iter().cloned().enumerate() {
1735            logic
1736                .set_out_handler(
1737                    &outlet,
1738                    Box::new(Out {
1739                        index,
1740                        inlet: inlet.clone(),
1741                        outlets: outlets.clone(),
1742                        state: Arc::clone(&state),
1743                    }),
1744                )
1745                .unwrap();
1746        }
1747        logic
1748    }
1749}
1750
/// Fan-out junction that splits each `(A, B)` pair, sending the `A` half to
/// the first output and the `B` half to the second.
#[derive(Clone, Debug)]
pub struct Unzip<A: 'static, B: 'static> {
    _marker: PhantomData<fn() -> (A, B)>,
}

impl<A: 'static, B: 'static> Unzip<A, B> {
    /// Creates an unzip junction; it carries no runtime configuration.
    #[must_use]
    pub fn new() -> Self {
        Unzip {
            _marker: PhantomData,
        }
    }
}

impl<A: 'static, B: 'static> Default for Unzip<A, B> {
    /// Same as [`Unzip::new`].
    fn default() -> Self {
        Unzip::new()
    }
}
1765impl<A: 'static, B: 'static> Default for Unzip<A, B> {
1766    fn default() -> Self {
1767        Self::new()
1768    }
1769}
1770
1771impl<A, B> GraphStage for Unzip<A, B>
1772where
1773    A: Clone + Send + 'static,
1774    B: Clone + Send + 'static,
1775{
1776    type Shape = FanOutShape2<(A, B), A, B>;
1777
1778    fn name(&self) -> &str {
1779        "Unzip"
1780    }
1781
1782    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1783        FanOutShape2::new(
1784            allocator.inlet("Unzip.in"),
1785            allocator.outlet("Unzip.out0"),
1786            allocator.outlet("Unzip.out1"),
1787        )
1788    }
1789
1790    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1791        let split = Arc::new(|dv: DatumValue| -> (DatumValue, DatumValue) {
1792            let pair: (A, B) =
1793                downcast_datum(dv, "unzip", || "Unzip.in").expect("unzip: wrong element type");
1794            (datum(pair.0), datum(pair.1))
1795        });
1796        // Typed split: `|(a, b): (A, B)| -> (A, B)`.  Stored as an opaque
1797        // `Arc<StageTypedUnzipFn>` and down-cast at plan time.
1798        // The `Arc<dyn Fn((A, B)) -> (A, B)>` wrapper is intentionally verbose;
1799        // suppress the type_complexity lint for this one binding.
1800        #[allow(clippy::type_complexity)]
1801        let typed_split_fn: Arc<dyn Fn((A, B)) -> (A, B) + Send + Sync> =
1802            Arc::new(|pair: (A, B)| pair);
1803        let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
1804        StageSpec::unzip(
1805            Arc::from(self.name()),
1806            shape.inlets(),
1807            shape.outlets(),
1808            split,
1809            typed_split,
1810        )
1811    }
1812
1813    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1814        UnzipWith::new(|pair: (A, B)| pair).create_logic(shape)
1815    }
1816}
1817
/// Fan-out junction that applies `split` to every incoming element and routes
/// the two halves of the result to its two outputs.
#[derive(Clone)]
pub struct UnzipWith<In: 'static, Out0: 'static, Out1: 'static> {
    split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
    _marker: PhantomData<fn(In) -> (Out0, Out1)>,
}

impl<In: 'static, Out0: 'static, Out1: 'static> fmt::Debug for UnzipWith<In, Out0, Out1> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `split` is an opaque closure, so no fields are rendered.
        let builder = &mut f.debug_struct("UnzipWith");
        builder.finish_non_exhaustive()
    }
}

impl<In: 'static, Out0: 'static, Out1: 'static> UnzipWith<In, Out0, Out1> {
    /// Wraps `split` so that it can be shared by every materialization.
    #[must_use]
    pub fn new<F>(split: F) -> Self
    where
        F: Fn(In) -> (Out0, Out1) + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::new(split);
        UnzipWith {
            split: shared,
            _marker: PhantomData,
        }
    }
}
1842
1843impl<In, Out0, Out1> GraphStage for UnzipWith<In, Out0, Out1>
1844where
1845    In: Clone + Send + 'static,
1846    Out0: Clone + Send + 'static,
1847    Out1: Clone + Send + 'static,
1848{
1849    type Shape = FanOutShape2<In, Out0, Out1>;
1850
1851    fn name(&self) -> &str {
1852        "UnzipWith"
1853    }
1854
1855    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1856        FanOutShape2::new(
1857            allocator.inlet("UnzipWith.in"),
1858            allocator.outlet("UnzipWith.out0"),
1859            allocator.outlet("UnzipWith.out1"),
1860        )
1861    }
1862
1863    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1864        let split_fn = Arc::clone(&self.split);
1865        let split = Arc::new(move |dv: DatumValue| -> (DatumValue, DatumValue) {
1866            let value: In = downcast_datum(dv, "unzip_with", || "UnzipWith.in")
1867                .expect("unzip-with: wrong element type");
1868            let (out0, out1) = split_fn(value);
1869            (datum(out0), datum(out1))
1870        });
1871        // Typed split: `Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>`.
1872        // Stored as opaque `Arc<StageTypedUnzipFn>` and down-cast at plan time.
1873        let typed_split_fn: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::clone(&self.split);
1874        let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
1875        StageSpec::unzip(
1876            Arc::from(self.name()),
1877            shape.inlets(),
1878            shape.outlets(),
1879            split,
1880            typed_split,
1881        )
1882    }
1883
1884    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1885        struct State {
1886            left_open: bool,
1887            right_open: bool,
1888            upstream_closed: bool,
1889        }
1890
1891        struct InHandlerState<In: 'static, Out0: 'static, Out1: 'static> {
1892            inlet_id: PortId,
1893            inlet: Inlet<In>,
1894            out0: Outlet<Out0>,
1895            out1: Outlet<Out1>,
1896            split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
1897            state: Arc<Mutex<State>>,
1898        }
1899
1900        impl<In, Out0, Out1> InHandler for InHandlerState<In, Out0, Out1>
1901        where
1902            In: Clone + Send + 'static,
1903            Out0: Clone + Send + 'static,
1904            Out1: Clone + Send + 'static,
1905        {
1906            fn on_push(
1907                &mut self,
1908                logic: &mut GraphStageLogic,
1909                _inlet: AnyInlet,
1910            ) -> StreamResult<()> {
1911                let value: In = logic.grab_datum(self.inlet_id).and_then(|value| {
1912                    downcast_datum(value, "grab", || {
1913                        format!("inlet#{}", self.inlet_id.as_usize())
1914                    })
1915                })?;
1916                let (left, right) = (self.split)(value);
1917                let state = self.state.lock().expect("unzip-with state poisoned");
1918                if state.left_open {
1919                    logic.push(&self.out0, left)?;
1920                }
1921                if state.right_open {
1922                    logic.push(&self.out1, right)?;
1923                }
1924                drop(state);
1925                let state = self.state.lock().expect("unzip-with state poisoned");
1926                let left_ready = !state.left_open || logic.is_available(&self.out0);
1927                let right_ready = !state.right_open || logic.is_available(&self.out1);
1928                if (state.left_open || state.right_open)
1929                    && left_ready
1930                    && right_ready
1931                    && !logic.has_been_pulled(&self.inlet)
1932                {
1933                    logic.pull(&self.inlet)?;
1934                }
1935                Ok(())
1936            }
1937
1938            fn on_upstream_finish(
1939                &mut self,
1940                logic: &mut GraphStageLogic,
1941                _inlet: AnyInlet,
1942            ) -> StreamResult<()> {
1943                self.state
1944                    .lock()
1945                    .expect("unzip-with state poisoned")
1946                    .upstream_closed = true;
1947                if !logic.is_closed(&self.out0) {
1948                    logic.complete(&self.out0)?;
1949                }
1950                if !logic.is_closed(&self.out1) {
1951                    logic.complete(&self.out1)?;
1952                }
1953                Ok(())
1954            }
1955        }
1956
1957        struct Out<In: 'static, Out0: 'static, Out1: 'static> {
1958            is_left: bool,
1959            inlet: Inlet<In>,
1960            out0: Outlet<Out0>,
1961            out1: Outlet<Out1>,
1962            state: Arc<Mutex<State>>,
1963        }
1964
1965        impl<In, Out0, Out1> OutHandler for Out<In, Out0, Out1>
1966        where
1967            In: Clone + Send + 'static,
1968            Out0: Clone + Send + 'static,
1969            Out1: Clone + Send + 'static,
1970        {
1971            fn on_pull(
1972                &mut self,
1973                logic: &mut GraphStageLogic,
1974                _outlet: AnyOutlet,
1975            ) -> StreamResult<()> {
1976                let state = self.state.lock().expect("unzip-with state poisoned");
1977                let left_ready = !state.left_open || logic.is_available(&self.out0);
1978                let right_ready = !state.right_open || logic.is_available(&self.out1);
1979                if state.upstream_closed {
1980                    drop(state);
1981                    if !logic.is_closed(&self.out0) {
1982                        logic.complete(&self.out0)?;
1983                    }
1984                    if !logic.is_closed(&self.out1) {
1985                        logic.complete(&self.out1)?;
1986                    }
1987                } else if (state.left_open || state.right_open)
1988                    && left_ready
1989                    && right_ready
1990                    && !logic.has_been_pulled(&self.inlet)
1991                {
1992                    drop(state);
1993                    logic.pull(&self.inlet)?;
1994                }
1995                Ok(())
1996            }
1997
1998            fn on_downstream_finish(
1999                &mut self,
2000                logic: &mut GraphStageLogic,
2001                _outlet: AnyOutlet,
2002            ) -> StreamResult<()> {
2003                let mut state = self.state.lock().expect("unzip-with state poisoned");
2004                if self.is_left {
2005                    state.left_open = false;
2006                } else {
2007                    state.right_open = false;
2008                }
2009                if !state.left_open && !state.right_open {
2010                    logic.complete_stage()?;
2011                    return Ok(());
2012                }
2013                let left_ready = !state.left_open || logic.is_available(&self.out0);
2014                let right_ready = !state.right_open || logic.is_available(&self.out1);
2015                if !state.upstream_closed
2016                    && (state.left_open || state.right_open)
2017                    && left_ready
2018                    && right_ready
2019                    && !logic.has_been_pulled(&self.inlet)
2020                {
2021                    logic.pull(&self.inlet)?;
2022                }
2023                Ok(())
2024            }
2025        }
2026
2027        let inlet = shape.inlet();
2028        let out0 = shape.out0();
2029        let out1 = shape.out1();
2030        let state = Arc::new(Mutex::new(State {
2031            left_open: true,
2032            right_open: true,
2033            upstream_closed: false,
2034        }));
2035        let mut logic = GraphStageLogic::new(shape);
2036        logic
2037            .set_handler(
2038                &inlet,
2039                Box::new(InHandlerState {
2040                    inlet_id: inlet.id(),
2041                    inlet: inlet.clone(),
2042                    out0: out0.clone(),
2043                    out1: out1.clone(),
2044                    split: Arc::clone(&self.split),
2045                    state: Arc::clone(&state),
2046                }),
2047            )
2048            .unwrap();
2049        logic
2050            .set_out_handler(
2051                &out0,
2052                Box::new(Out {
2053                    is_left: true,
2054                    inlet: inlet.clone(),
2055                    out0: out0.clone(),
2056                    out1: out1.clone(),
2057                    state: Arc::clone(&state),
2058                }),
2059            )
2060            .unwrap();
2061        logic
2062            .set_out_handler(
2063                &out1.clone(),
2064                Box::new(Out {
2065                    is_left: false,
2066                    inlet: inlet.clone(),
2067                    out0: out0.clone(),
2068                    out1: out1.clone(),
2069                    state,
2070                }),
2071            )
2072            .unwrap();
2073        logic
2074    }
2075}
2076
/// Pass-through stage over elements of type `T`, lowered to
/// `StageSpec::async_boundary` by its `GraphStage` impl.
///
/// `T` is carried only as `PhantomData<fn() -> T>`, so the value holds no `T`.
#[derive(Clone, Debug)]
pub struct AsyncBoundary<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> AsyncBoundary<T> {
    /// Creates a boundary stage; equivalent to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
}

impl<T: 'static> Default for AsyncBoundary<T> {
    fn default() -> Self {
        AsyncBoundary {
            _marker: PhantomData,
        }
    }
}
2096
2097impl<T> GraphStage for AsyncBoundary<T>
2098where
2099    T: Clone + Send + 'static,
2100{
2101    type Shape = FlowShape<T, T>;
2102
2103    fn name(&self) -> &str {
2104        "AsyncBoundary"
2105    }
2106
2107    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
2108        let first_id = next_port_id_block(2);
2109        FlowShape::new(
2110            Inlet::with_arc_name(first_id, async_boundary_inlet_name()),
2111            Outlet::with_arc_name(first_id.offset(1), async_boundary_outlet_name()),
2112        )
2113    }
2114
2115    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2116        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
2117    }
2118
2119    fn stage_spec_with_ports(
2120        &self,
2121        _shape: &Self::Shape,
2122        inlets: Vec<AnyInlet>,
2123        outlets: Vec<AnyOutlet>,
2124    ) -> StageSpec {
2125        StageSpec::async_boundary(async_boundary_stage_name(), inlets, outlets)
2126    }
2127}