1use super::*;
2
/// Marker type for the graph stage registered under the name `"Identity"`.
///
/// The stage stores no data; `T` is the element type used for both its inlet
/// and its outlet.
#[derive(Clone, Debug)]
pub struct Identity<T: 'static> {
    /// Ties the stage to `T` without holding a value of that type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Identity<T> {
    /// Creates a new `Identity` stage.
    #[must_use]
    pub fn new() -> Self {
        Self { _marker: PhantomData }
    }
}

impl<T: 'static> Default for Identity<T> {
    /// Produces the same value as [`Identity::new`].
    fn default() -> Self {
        Self { _marker: PhantomData }
    }
}
22
23impl<T> GraphStage for Identity<T>
24where
25 T: Clone + Send + 'static,
26{
27 type Shape = FlowShape<T, T>;
28
29 fn name(&self) -> &str {
30 "Identity"
31 }
32
33 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
34 let first_id = next_port_id_block(2);
35 FlowShape::new(
36 Inlet::with_arc_name(first_id, identity_inlet_name()),
37 Outlet::with_arc_name(first_id.offset(1), identity_outlet_name()),
38 )
39 }
40
41 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
42 StageSpec::identity(shape.inlets(), shape.outlets())
43 }
44
45 fn stage_spec_with_ports(
46 &self,
47 _shape: &Self::Shape,
48 inlets: Vec<AnyInlet>,
49 outlets: Vec<AnyOutlet>,
50 ) -> StageSpec {
51 StageSpec::identity(inlets, outlets)
52 }
53}
54
/// Graph stage named `"Map"` that holds a shared `In -> Out` function.
///
/// Cloning the stage shares the same function through its `Arc`.
#[derive(Clone)]
pub struct MapStage<In: 'static, Out: 'static> {
    /// The mapping function invoked for each element.
    f: Arc<dyn Fn(In) -> Out + Send + Sync>,
    /// Ties the stage to its input and output element types.
    _marker: PhantomData<fn(In) -> Out>,
}

impl<In: 'static, Out: 'static> fmt::Debug for MapStage<In, Out> {
    /// Prints only a fixed `name` field; the stored closure is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MapStage");
        builder.field("name", &"Map");
        builder.finish_non_exhaustive()
    }
}

impl<In: 'static, Out: 'static> MapStage<In, Out> {
    /// Creates a map stage that wraps `f` in an `Arc`.
    #[must_use]
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(In) -> Out + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(In) -> Out + Send + Sync> = Arc::new(f);
        Self {
            f: shared,
            _marker: PhantomData,
        }
    }
}
81
82impl<In, Out> GraphStage for MapStage<In, Out>
83where
84 In: Clone + Send + 'static,
85 Out: Clone + Send + 'static,
86{
87 type Shape = FlowShape<In, Out>;
88
89 fn name(&self) -> &str {
90 "Map"
91 }
92
93 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
94 let first_id = next_port_id_block(2);
95 FlowShape::new(
96 Inlet::with_arc_name(first_id, map_inlet_name()),
97 Outlet::with_arc_name(first_id.offset(1), map_outlet_name()),
98 )
99 }
100
101 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
102 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
103 }
104
105 fn stage_spec_with_ports(
106 &self,
107 _shape: &Self::Shape,
108 inlets: Vec<AnyInlet>,
109 outlets: Vec<AnyOutlet>,
110 ) -> StageSpec {
111 let f = Arc::clone(&self.f);
112 let typed = Arc::new(Arc::clone(&self.f)) as Arc<StageTypedMapFn>;
113 let mapper = Arc::new(move |value: DatumValue| {
114 let value: In = downcast_datum(value, "map", || "Map.in")?;
115 Ok(datum(f(value)))
116 });
117 StageSpec::map(map_stage_name(), inlets, outlets, mapper, typed)
118 }
119}
120
/// Fan-out graph stage named `"Broadcast"` with a configurable outlet count.
#[derive(Clone, Debug)]
pub struct Broadcast<T: 'static> {
    /// Number of outlets allocated for the stage.
    outputs: usize,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Broadcast<T> {
    /// Creates a broadcast stage with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        assert!(
            outputs != 0,
            "broadcast output count must be greater than zero"
        );
        Self { outputs, _marker: PhantomData }
    }
}
140
141impl<T> GraphStage for Broadcast<T>
142where
143 T: Clone + Send + 'static,
144{
145 type Shape = FanOutShape<T, T>;
146
147 fn name(&self) -> &str {
148 "Broadcast"
149 }
150
151 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
152 let inlet = allocator.inlet_arc(broadcast_inlet_name());
153 let outlets = (0..self.outputs)
154 .map(|index| allocator.outlet(format!("Broadcast.out{index}")))
155 .collect();
156 FanOutShape::new(inlet, outlets)
157 }
158
159 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
160 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
161 }
162
163 fn stage_spec_with_ports(
164 &self,
165 _shape: &Self::Shape,
166 inlets: Vec<AnyInlet>,
167 outlets: Vec<AnyOutlet>,
168 ) -> StageSpec {
169 StageSpec::broadcast(broadcast_stage_name(), inlets, outlets)
170 }
171}
172
/// Fan-out graph stage named `"Balance"` with a configurable outlet count.
#[derive(Clone, Debug)]
pub struct Balance<T: 'static> {
    /// Number of outlets allocated for the stage.
    outputs: usize,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Balance<T> {
    /// Creates a balance stage with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        assert!(
            outputs != 0,
            "balance output count must be greater than zero"
        );
        Self { outputs, _marker: PhantomData }
    }
}
192
193impl<T> GraphStage for Balance<T>
194where
195 T: Clone + Send + 'static,
196{
197 type Shape = FanOutShape<T, T>;
198
199 fn name(&self) -> &str {
200 "Balance"
201 }
202
203 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
204 let inlet = allocator.inlet_arc(balance_inlet_name());
205 let outlets = (0..self.outputs)
206 .map(|index| allocator.outlet(format!("Balance.out{index}")))
207 .collect();
208 FanOutShape::new(inlet, outlets)
209 }
210
211 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
212 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
213 }
214
215 fn stage_spec_with_ports(
216 &self,
217 _shape: &Self::Shape,
218 inlets: Vec<AnyInlet>,
219 outlets: Vec<AnyOutlet>,
220 ) -> StageSpec {
221 StageSpec::balance(balance_stage_name(), inlets, outlets)
222 }
223}
224
/// Fan-in graph stage named `"Merge"` with a configurable inlet count.
#[derive(Clone, Debug)]
pub struct Merge<T: 'static> {
    /// Number of inlets allocated for the stage.
    inputs: usize,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Merge<T> {
    /// Creates a merge stage with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs != 0, "merge input count must be greater than zero");
        Self { inputs, _marker: PhantomData }
    }
}
241
242impl<T> GraphStage for Merge<T>
243where
244 T: Clone + Send + 'static,
245{
246 type Shape = FanInShape<T, T>;
247
248 fn name(&self) -> &str {
249 "Merge"
250 }
251
252 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
253 let inlets = (0..self.inputs)
254 .map(|index| allocator.inlet(format!("Merge.in{index}")))
255 .collect();
256 FanInShape::new(inlets, allocator.outlet_arc(merge_outlet_name()))
257 }
258
259 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
260 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
261 }
262
263 fn stage_spec_with_ports(
264 &self,
265 _shape: &Self::Shape,
266 inlets: Vec<AnyInlet>,
267 outlets: Vec<AnyOutlet>,
268 ) -> StageSpec {
269 StageSpec::merge(merge_stage_name(), inlets, outlets)
270 }
271}
272
/// Fan-in graph stage named `"Concat"` with a configurable inlet count.
#[derive(Clone, Debug)]
pub struct Concat<T: 'static> {
    /// Number of inlets allocated for the stage.
    inputs: usize,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Concat<T> {
    /// Creates a concat stage with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs >= 2, "concat input count must be greater than one");
        Self { inputs, _marker: PhantomData }
    }
}
289
290impl<T> GraphStage for Concat<T>
291where
292 T: Clone + Send + 'static,
293{
294 type Shape = FanInShape<T, T>;
295
296 fn name(&self) -> &str {
297 "Concat"
298 }
299
300 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
301 let inlets = (0..self.inputs)
302 .map(|index| allocator.inlet(format!("Concat.in{index}")))
303 .collect();
304 FanInShape::new(inlets, allocator.outlet_arc(concat_outlet_name()))
305 }
306
307 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
308 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
309 }
310
311 fn stage_spec_with_ports(
312 &self,
313 _shape: &Self::Shape,
314 inlets: Vec<AnyInlet>,
315 outlets: Vec<AnyOutlet>,
316 ) -> StageSpec {
317 StageSpec::concat(concat_stage_name(), inlets, outlets)
318 }
319}
320
/// Marker type for the graph stage registered under the name `"OrElse"`.
///
/// The stage stores no data; it is allocated with a primary inlet, a
/// secondary inlet and one outlet, all of element type `T`.
#[derive(Clone, Debug)]
pub struct OrElse<T: 'static> {
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> OrElse<T> {
    /// Creates a new `OrElse` stage.
    #[must_use]
    pub fn new() -> Self {
        Self { _marker: PhantomData }
    }
}

impl<T: 'static> Default for OrElse<T> {
    /// Produces the same value as [`OrElse::new`].
    fn default() -> Self {
        Self { _marker: PhantomData }
    }
}
340
341impl<T> GraphStage for OrElse<T>
342where
343 T: Clone + Send + 'static,
344{
345 type Shape = FanInShape<T, T>;
346
347 fn name(&self) -> &str {
348 "OrElse"
349 }
350
351 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
352 let inlets = vec![
353 allocator.inlet_arc(or_else_primary_name()),
354 allocator.inlet_arc(or_else_secondary_name()),
355 ];
356 FanInShape::new(inlets, allocator.outlet_arc(or_else_outlet_name()))
357 }
358
359 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
360 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
361 }
362
363 fn stage_spec_with_ports(
364 &self,
365 _shape: &Self::Shape,
366 inlets: Vec<AnyInlet>,
367 outlets: Vec<AnyOutlet>,
368 ) -> StageSpec {
369 StageSpec::or_else(or_else_stage_name(), inlets, outlets)
370 }
371}
372
/// Fan-in graph stage named `"Interleave"`.
///
/// `segment_size` and `eager_close` are handed unchanged to the stage spec.
#[derive(Clone, Debug)]
pub struct Interleave<T: 'static> {
    /// Number of inlets allocated for the stage.
    inputs: usize,
    /// Segment size forwarded to the stage spec.
    segment_size: usize,
    /// Eager-close flag forwarded to the stage spec.
    eager_close: bool,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Interleave<T> {
    /// Creates an interleave stage with `eager_close` set to `false`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as
    /// [`Interleave::new_with_eager_close`].
    #[must_use]
    pub fn new(inputs: usize, segment_size: usize) -> Self {
        let eager_close = false;
        Self::new_with_eager_close(inputs, segment_size, eager_close)
    }

    /// Creates an interleave stage with an explicit `eager_close` flag.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two, or if `segment_size` is zero.
    /// The input count is checked first.
    #[must_use]
    pub fn new_with_eager_close(inputs: usize, segment_size: usize, eager_close: bool) -> Self {
        assert!(
            inputs >= 2,
            "interleave input count must be greater than one"
        );
        assert!(
            segment_size != 0,
            "interleave segment size must be greater than zero"
        );
        Self { inputs, segment_size, eager_close, _marker: PhantomData }
    }
}
405
406impl<T> GraphStage for Interleave<T>
407where
408 T: Clone + Send + 'static,
409{
410 type Shape = FanInShape<T, T>;
411
412 fn name(&self) -> &str {
413 "Interleave"
414 }
415
416 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
417 let inlets = (0..self.inputs)
418 .map(|index| allocator.inlet(format!("Interleave.in{index}")))
419 .collect();
420 FanInShape::new(inlets, allocator.outlet_arc(interleave_outlet_name()))
421 }
422
423 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
424 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
425 }
426
427 fn stage_spec_with_ports(
428 &self,
429 _shape: &Self::Shape,
430 inlets: Vec<AnyInlet>,
431 outlets: Vec<AnyOutlet>,
432 ) -> StageSpec {
433 StageSpec::interleave(
434 interleave_stage_name(),
435 inlets,
436 outlets,
437 self.segment_size,
438 self.eager_close,
439 )
440 }
441}
442
/// Fan-in graph stage named `"MergePreferred"`.
///
/// The stage is allocated with one preferred inlet plus `secondary_ports`
/// secondary inlets.
#[derive(Clone, Debug)]
pub struct MergePreferred<T: 'static> {
    /// Number of secondary inlets allocated besides the preferred one.
    secondary_ports: usize,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePreferred<T> {
    /// Creates a merge-preferred stage with `secondary_ports` secondary inlets.
    ///
    /// # Panics
    ///
    /// Panics if `secondary_ports` is zero.
    #[must_use]
    pub fn new(secondary_ports: usize) -> Self {
        assert!(
            secondary_ports != 0,
            "merge-preferred secondary input count must be greater than zero"
        );
        Self { secondary_ports, _marker: PhantomData }
    }
}
462
463impl<T> GraphStage for MergePreferred<T>
464where
465 T: Clone + Send + 'static,
466{
467 type Shape = MergePreferredShape<T>;
468
469 fn name(&self) -> &str {
470 "MergePreferred"
471 }
472
473 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
474 let preferred = allocator.inlet_arc(merge_preferred_preferred_name());
475 let secondary = (0..self.secondary_ports)
476 .map(|index| allocator.inlet(format!("MergePreferred.in{index}")))
477 .collect();
478 MergePreferredShape::new(
479 preferred,
480 secondary,
481 allocator.outlet_arc(merge_preferred_outlet_name()),
482 )
483 }
484
485 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
486 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
487 }
488
489 fn stage_spec_with_ports(
490 &self,
491 _shape: &Self::Shape,
492 inlets: Vec<AnyInlet>,
493 outlets: Vec<AnyOutlet>,
494 ) -> StageSpec {
495 StageSpec::merge_preferred(merge_preferred_stage_name(), inlets, outlets)
496 }
497}
498
/// Fan-in graph stage named `"MergePrioritized"`.
///
/// One inlet is allocated per entry in `weights`; the weights themselves are
/// handed to the stage spec.
#[derive(Clone, Debug)]
pub struct MergePrioritized<T: 'static> {
    /// One strictly positive weight per inlet.
    weights: Vec<usize>,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePrioritized<T> {
    /// Creates a prioritized merge stage with one inlet per weight.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty or if any weight is zero. Emptiness is
    /// checked first.
    #[must_use]
    pub fn new(weights: Vec<usize>) -> Self {
        let has_inputs = !weights.is_empty();
        assert!(has_inputs, "prioritized merge must have inputs");
        let all_positive = !weights.contains(&0);
        assert!(
            all_positive,
            "prioritized merge weights must be greater than zero"
        );
        Self { weights, _marker: PhantomData }
    }
}
519
520impl<T> GraphStage for MergePrioritized<T>
521where
522 T: Clone + Send + 'static,
523{
524 type Shape = FanInShape<T, T>;
525
526 fn name(&self) -> &str {
527 "MergePrioritized"
528 }
529
530 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
531 let inlets = (0..self.weights.len())
532 .map(|index| allocator.inlet(format!("MergePrioritized.in{index}")))
533 .collect();
534 FanInShape::new(
535 inlets,
536 allocator.outlet_arc(merge_prioritized_outlet_name()),
537 )
538 }
539
540 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
541 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
542 }
543
544 fn stage_spec_with_ports(
545 &self,
546 _shape: &Self::Shape,
547 inlets: Vec<AnyInlet>,
548 outlets: Vec<AnyOutlet>,
549 ) -> StageSpec {
550 StageSpec::merge_prioritized(
551 merge_prioritized_stage_name(),
552 inlets,
553 outlets,
554 self.weights.clone(),
555 )
556 }
557}
558
/// Marker type for the graph stage registered under the name `"Zip"`.
///
/// The stage stores no data; `Left` and `Right` are the element types of its
/// two inlets.
#[derive(Clone, Debug)]
pub struct Zip<Left: 'static, Right: 'static> {
    /// Ties the stage to both element types.
    _marker: PhantomData<fn() -> (Left, Right)>,
}

impl<Left: 'static, Right: 'static> Zip<Left, Right> {
    /// Creates a new `Zip` stage.
    #[must_use]
    pub fn new() -> Self {
        Self { _marker: PhantomData }
    }
}

impl<Left: 'static, Right: 'static> Default for Zip<Left, Right> {
    /// Produces the same value as [`Zip::new`].
    fn default() -> Self {
        Self { _marker: PhantomData }
    }
}
578
579impl<Left, Right> GraphStage for Zip<Left, Right>
580where
581 Left: Clone + Send + 'static,
582 Right: Clone + Send + 'static,
583{
584 type Shape = ZipShape<Left, Right>;
585
586 fn name(&self) -> &str {
587 "Zip"
588 }
589
590 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
591 let first_id = next_port_id_block(3);
592 ZipShape::new(
593 Inlet::with_arc_name(first_id, zip_in0_name()),
594 Inlet::with_arc_name(first_id.offset(1), zip_in1_name()),
595 Outlet::with_arc_name(first_id.offset(2), zip_outlet_name()),
596 )
597 }
598
599 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
600 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
601 }
602
603 fn stage_spec_with_ports(
604 &self,
605 _shape: &Self::Shape,
606 inlets: Vec<AnyInlet>,
607 outlets: Vec<AnyOutlet>,
608 ) -> StageSpec {
609 let zip = Arc::new(move |left: DatumValue, right: DatumValue| {
610 let left: Left = downcast_datum(left, "zip", || "Zip.in0")?;
611 let right: Right = downcast_datum(right, "zip", || "Zip.in1")?;
612 Ok(datum((left, right)))
613 });
614 StageSpec::zip(zip_stage_name(), inlets, outlets, zip)
615 }
616}
617
/// Marker type for the graph stage registered under the name `"MergeSorted"`.
///
/// The stage stores no data; it is allocated with two inlets and one outlet,
/// all of element type `T`.
#[derive(Clone, Debug)]
pub struct MergeSorted<T: 'static> {
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeSorted<T> {
    /// Creates a new `MergeSorted` stage.
    #[must_use]
    pub fn new() -> Self {
        Self { _marker: PhantomData }
    }
}

impl<T: 'static> Default for MergeSorted<T> {
    /// Produces the same value as [`MergeSorted::new`].
    fn default() -> Self {
        Self { _marker: PhantomData }
    }
}
637
/// `GraphStage` implementation for `MergeSorted`: two inlets and one outlet of
/// the same ordered element type `T`.
impl<T> GraphStage for MergeSorted<T>
where
    T: Clone + Ord + Send + 'static,
{
    type Shape = FanInShape<T, T>;

    /// Returns the fixed stage name.
    fn name(&self) -> &str {
        "MergeSorted"
    }

    /// Allocates two inlets (`MergeSorted.in0`, `MergeSorted.in1`) followed by
    /// one outlet (`MergeSorted.out`) from `allocator`.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlets = vec![
            allocator.inlet("MergeSorted.in0"),
            allocator.inlet("MergeSorted.in1"),
        ];
        FanInShape::new(inlets, allocator.outlet("MergeSorted.out"))
    }

    /// Builds the merge-sorted stage spec with a type-erased comparator.
    ///
    /// # Panics
    ///
    /// The comparator panics when either datum is not a `T`.
    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        // Downcast both datums to `T` and defer to `T`'s `Ord` implementation.
        let compare = Arc::new(
            move |a: &DatumValue, b: &DatumValue| -> std::cmp::Ordering {
                let a_t: &T = a
                    .as_any_ref()
                    .downcast_ref::<T>()
                    .expect("merge-sorted compare: wrong element type");
                let b_t: &T = b
                    .as_any_ref()
                    .downcast_ref::<T>()
                    .expect("merge-sorted compare: wrong element type");
                a_t.cmp(b_t)
            },
        );
        StageSpec::merge_sorted(
            Arc::from(self.name()),
            shape.inlets(),
            shape.outlets(),
            compare,
        )
    }

    /// Builds the stage logic: one `In` handler per inlet and one `Out`
    /// handler, all sharing a single mutex-guarded `State`.
    ///
    /// # Panics
    ///
    /// Panics if the shape lacks inlet 0 or 1, or if registering a handler
    /// fails. The handlers themselves panic if the shared mutex is poisoned.
    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
        // Buffers shared by all handlers of one stage instance.
        struct State<T> {
            // Elements received from the left inlet and not yet ordered.
            left: VecDeque<T>,
            // Elements received from the right inlet and not yet ordered.
            right: VecDeque<T>,
            // Set once the left upstream has finished.
            left_closed: bool,
            // Set once the right upstream has finished.
            right_closed: bool,
            // Elements already placed in output order, waiting to be pushed.
            pending: VecDeque<T>,
        }

        impl<T> Default for State<T> {
            fn default() -> Self {
                Self {
                    left: VecDeque::new(),
                    right: VecDeque::new(),
                    left_closed: false,
                    right_closed: false,
                    pending: VecDeque::new(),
                }
            }
        }

        // Moves at most one element into `pending` and reports whether it did.
        // With both fronts present the smaller one is taken (left wins ties);
        // with only one side present it is taken only if the other side is
        // already closed.
        fn maybe_queue_output<T>(state: &mut State<T>) -> bool
        where
            T: Clone + Ord,
        {
            let next = match (state.left.front(), state.right.front()) {
                (Some(left), Some(right)) => {
                    if left <= right {
                        state.left.pop_front()
                    } else {
                        state.right.pop_front()
                    }
                }
                (Some(_), None) if state.right_closed => state.left.pop_front(),
                (None, Some(_)) if state.left_closed => state.right.pop_front(),
                _ => None,
            };
            if let Some(value) = next {
                state.pending.push_back(value);
                true
            } else {
                false
            }
        }

        // Completes the outlet once both inputs are closed and every buffer is
        // empty, unless the outlet is already closed.
        fn maybe_complete<T>(
            logic: &mut GraphStageLogic,
            outlet: &Outlet<T>,
            state: &State<T>,
        ) -> StreamResult<()>
        where
            T: Clone + Send + 'static,
        {
            if state.left_closed
                && state.right_closed
                && state.left.is_empty()
                && state.right.is_empty()
                && state.pending.is_empty()
                && !logic.is_closed(outlet)
            {
                logic.complete(outlet)?;
            }
            Ok(())
        }

        // Pulls each inlet whose buffer is empty, that is not closed, and that
        // has not already been pulled.
        fn maybe_pull<T>(
            logic: &mut GraphStageLogic,
            left: &Inlet<T>,
            right: &Inlet<T>,
            state: &State<T>,
        ) -> StreamResult<()>
        where
            T: Clone + Send + 'static,
        {
            if state.left.is_empty() && !state.left_closed && !logic.has_been_pulled(left) {
                logic.pull(left)?;
            }
            if state.right.is_empty() && !state.right_closed && !logic.has_been_pulled(right) {
                logic.pull(right)?;
            }
            Ok(())
        }

        // Pushes at most one pending element when the outlet is available.
        // This function takes the state lock itself, so callers must not be
        // holding a guard when they call it.
        fn maybe_drain<T>(
            logic: &mut GraphStageLogic,
            outlet: &Outlet<T>,
            state: &Arc<Mutex<State<T>>>,
        ) -> StreamResult<()>
        where
            T: Clone + Ord + Send + 'static,
        {
            let next = if logic.is_available(outlet) {
                state
                    .lock()
                    .expect("merge-sorted state poisoned")
                    .pending
                    .pop_front()
            } else {
                None
            };
            if let Some(value) = next {
                logic.push(outlet, value)?;
            }
            Ok(())
        }

        // Inlet handler; one instance is registered per inlet and `inlet_id`
        // records which of the two inlets this instance serves.
        struct In<T: 'static> {
            inlet_id: PortId,
            left: Inlet<T>,
            right: Inlet<T>,
            outlet: Outlet<T>,
            state: Arc<Mutex<State<T>>>,
        }

        impl<T> InHandler for In<T>
        where
            T: Clone + Ord + Send + 'static,
        {
            // Buffers the pushed element on its side, orders everything that
            // can be ordered, then drains, completes and pulls as applicable.
            fn on_push(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                let value: T = logic.grab_datum(self.inlet_id).and_then(|value| {
                    downcast_datum(value, "grab", || {
                        format!("inlet#{}", self.inlet_id.as_usize())
                    })
                })?;
                // Scoped so the guard is released before `maybe_drain` locks again.
                {
                    let mut state = self.state.lock().expect("merge-sorted state poisoned");
                    if self.inlet_id == self.left.id() {
                        state.left.push_back(value);
                    } else {
                        state.right.push_back(value);
                    }
                    while maybe_queue_output(&mut state) {}
                }
                maybe_drain(logic, &self.outlet, &self.state)?;
                let state = self.state.lock().expect("merge-sorted state poisoned");
                maybe_complete(logic, &self.outlet, &state)?;
                maybe_pull(logic, &self.left, &self.right, &state)
            }

            // Marks this handler's side as closed; that may unblock elements
            // buffered on the other side, so ordering is re-run before the
            // drain / complete / pull sequence.
            fn on_upstream_finish(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                // Scoped so the guard is released before `maybe_drain` locks again.
                {
                    let mut state = self.state.lock().expect("merge-sorted state poisoned");
                    if self.inlet_id == self.left.id() {
                        state.left_closed = true;
                    } else {
                        state.right_closed = true;
                    }
                    while maybe_queue_output(&mut state) {}
                }
                maybe_drain(logic, &self.outlet, &self.state)?;
                let state = self.state.lock().expect("merge-sorted state poisoned");
                maybe_complete(logic, &self.outlet, &state)?;
                maybe_pull(logic, &self.left, &self.right, &state)
            }
        }

        // Outlet handler.
        struct Out<T: 'static> {
            left: Inlet<T>,
            right: Inlet<T>,
            outlet: Outlet<T>,
            state: Arc<Mutex<State<T>>>,
        }

        impl<T> OutHandler for Out<T>
        where
            T: Clone + Ord + Send + 'static,
        {
            // On downstream demand: push one pending element if there is one,
            // then run the completion and pull checks.
            fn on_pull(
                &mut self,
                logic: &mut GraphStageLogic,
                _outlet: AnyOutlet,
            ) -> StreamResult<()> {
                maybe_drain(logic, &self.outlet, &self.state)?;
                let state = self.state.lock().expect("merge-sorted state poisoned");
                maybe_complete(logic, &self.outlet, &state)?;
                maybe_pull(logic, &self.left, &self.right, &state)
            }
        }

        let state = Arc::new(Mutex::new(State::<T>::default()));
        let left = shape.inlet(0).expect("merge-sorted left inlet");
        let right = shape.inlet(1).expect("merge-sorted right inlet");
        let outlet = shape.outlet();
        let mut logic = GraphStageLogic::new(shape);
        // Both inlet handlers hold clones of the same ports and the same state;
        // they differ only in `inlet_id`.
        logic
            .set_handler(
                &left,
                Box::new(In {
                    inlet_id: left.id(),
                    left: left.clone(),
                    right: right.clone(),
                    outlet: outlet.clone(),
                    state: Arc::clone(&state),
                }),
            )
            .unwrap();
        logic
            .set_handler(
                &right,
                Box::new(In {
                    inlet_id: right.id(),
                    left: left.clone(),
                    right: right.clone(),
                    outlet: outlet.clone(),
                    state: Arc::clone(&state),
                }),
            )
            .unwrap();
        logic
            .set_out_handler(
                &outlet.clone(),
                Box::new(Out {
                    left,
                    right,
                    outlet: outlet.clone(),
                    state,
                }),
            )
            .unwrap();
        logic
    }
}
908
/// Fan-in graph stage named `"MergeSequence"`.
///
/// Each element's sequence number is obtained through the stored extractor
/// function. Cloning the stage shares the extractor through its `Arc`.
#[derive(Clone)]
pub struct MergeSequence<T: 'static> {
    /// Number of inlets allocated for the stage.
    inputs: usize,
    /// Maps an element to its sequence number.
    extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for MergeSequence<T> {
    /// Prints only the input count; the extractor closure is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("MergeSequence");
        builder.field("inputs", &self.inputs);
        builder.finish_non_exhaustive()
    }
}

impl<T: 'static> MergeSequence<T> {
    /// Creates a merge-sequence stage with `inputs` inlets and the given
    /// sequence extractor.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new<F>(inputs: usize, extract_sequence: F) -> Self
    where
        F: Fn(&T) -> u64 + Send + Sync + 'static,
    {
        assert!(
            inputs >= 2,
            "merge sequence input count must be greater than one"
        );
        let extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync> = Arc::new(extract_sequence);
        Self { inputs, extract_sequence, _marker: PhantomData }
    }
}
941
/// `GraphStage` implementation for `MergeSequence`: `inputs` inlets and one
/// outlet of the same element type `T`.
impl<T> GraphStage for MergeSequence<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanInShape<T, T>;

    /// Returns the fixed stage name.
    fn name(&self) -> &str {
        "MergeSequence"
    }

    /// Allocates inlets named `MergeSequence.in0`, `MergeSequence.in1`, … in
    /// index order, then the outlet `MergeSequence.out`.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlets = (0..self.inputs)
            .map(|index| allocator.inlet(format!("MergeSequence.in{index}")))
            .collect();
        FanInShape::new(inlets, allocator.outlet("MergeSequence.out"))
    }

    /// Builds the merge-sequence stage spec.
    ///
    /// The extractor is supplied in two forms: a type-erased closure that
    /// downcasts each datum to `T` first, and the typed extractor itself
    /// wrapped in an outer `Arc`.
    ///
    /// # Panics
    ///
    /// The type-erased extractor panics when the datum is not a `T`.
    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let extract = Arc::clone(&self.extract_sequence);
        let extract_sequence = Arc::new(move |dv: &DatumValue| -> u64 {
            let t: &T = dv
                .as_any_ref()
                .downcast_ref::<T>()
                .expect("merge-sequence extract: wrong element type");
            extract(t)
        });
        let typed_extract_fn = Arc::clone(&self.extract_sequence);
        let typed_extract: Arc<dyn Fn(&T) -> u64 + Send + Sync> = typed_extract_fn;
        // The typed `Arc<dyn Fn>` is itself boxed in an `Arc` to produce the
        // `StageTypedSequenceFn` handle.
        let typed_extract: Arc<StageTypedSequenceFn> = Arc::new(typed_extract);
        StageSpec::merge_sequence(
            Arc::from(self.name()),
            shape.inlets(),
            shape.outlets(),
            self.inputs,
            extract_sequence,
            typed_extract,
        )
    }

    /// Builds the stage logic: one `In` handler per inlet and one `Out`
    /// handler, all sharing a single mutex-guarded `State`.
    ///
    /// # Panics
    ///
    /// Panics if registering a handler fails. The handlers themselves panic
    /// if the shared mutex is poisoned.
    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
        // A buffered element together with its extracted sequence number.
        #[derive(Clone)]
        struct Pending<T> {
            sequence: u64,
            elem: T,
        }

        // Buffers shared by all handlers of one stage instance.
        struct State<T> {
            // Sequence number expected next; starts at 0.
            next_sequence: u64,
            // Elements received but not yet next in sequence.
            pending: Vec<Pending<T>>,
            // Count of inlets whose upstream has finished.
            completed: usize,
            // Elements already in sequence order, waiting to be pushed.
            pending_output: VecDeque<T>,
        }

        // Moves every buffered element that continues the sequence into
        // `pending_output`, advancing `next_sequence` each time. Fails if a
        // second buffered element carries the same sequence number as the one
        // just removed.
        fn try_emit_pending<T>(state: &mut State<T>) -> StreamResult<()>
        where
            T: Clone + Send + 'static,
        {
            while let Some(index) = state
                .pending
                .iter()
                .position(|item| item.sequence == state.next_sequence)
            {
                let item = state.pending.remove(index);
                if state
                    .pending
                    .iter()
                    .any(|other| other.sequence == state.next_sequence)
                {
                    return Err(StreamError::Failed(format!(
                        "duplicate sequence {} on merge sequence",
                        state.next_sequence
                    )));
                }
                state.pending_output.push_back(item.elem);
                state.next_sequence += 1;
            }
            Ok(())
        }

        // Inlet handler; one instance is registered per inlet.
        struct In<T: 'static> {
            inlet_id: PortId,
            // Position of this inlet in the shape; used in error messages.
            inlet_index: usize,
            inlet: Inlet<T>,
            all_inlets: Vec<Inlet<T>>,
            outlet: Outlet<T>,
            extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
            state: Arc<Mutex<State<T>>>,
        }

        impl<T> InHandler for In<T>
        where
            T: Clone + Send + 'static,
        {
            // Buffers the pushed element, emits whatever is now in sequence,
            // pushes one element if the outlet is available, runs the
            // completion / stall check, and finally pulls this inlet again.
            fn on_push(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
                    downcast_datum(value, "grab", || {
                        format!("inlet#{}", self.inlet_id.as_usize())
                    })
                })?;
                // Scoped so the guard is released before the state is locked
                // again below.
                {
                    let mut state = self.state.lock().expect("merge-sequence state poisoned");
                    let sequence = (self.extract_sequence)(&elem);
                    // A sequence number below the expected one is an error.
                    if sequence < state.next_sequence {
                        return Err(StreamError::Failed(format!(
                            "sequence regression from {} to {} on port {}",
                            state.next_sequence, sequence, self.inlet_index
                        )));
                    }
                    state.pending.push(Pending { sequence, elem });
                    try_emit_pending(&mut state)?;
                }
                let next = if logic.is_available(&self.outlet) {
                    self.state
                        .lock()
                        .expect("merge-sequence state poisoned")
                        .pending_output
                        .pop_front()
                } else {
                    None
                };
                if let Some(value) = next {
                    logic.push(&self.outlet, value)?;
                }
                let state = self.state.lock().expect("merge-sequence state poisoned");
                // Complete when every inlet has finished and nothing is
                // buffered. Otherwise fail when the outlet is available,
                // nothing is ready to emit, and buffered elements plus
                // finished inlets add up to the inlet count.
                if state.completed == self.all_inlets.len()
                    && state.pending.is_empty()
                    && state.pending_output.is_empty()
                {
                    logic.complete(&self.outlet)?;
                } else if logic.is_available(&self.outlet)
                    && state.pending_output.is_empty()
                    && state.pending.len() + state.completed == self.all_inlets.len()
                {
                    return Err(StreamError::Failed(format!(
                        "expected sequence {}, but all input ports have pushed or are complete",
                        state.next_sequence
                    )));
                }
                if !logic.has_been_pulled(&self.inlet) {
                    logic.pull(&self.inlet)?;
                }
                Ok(())
            }

            // Counts the finished inlet, then runs the same completion / stall
            // check as `on_push`.
            fn on_upstream_finish(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                {
                    let mut state = self.state.lock().expect("merge-sequence state poisoned");
                    state.completed += 1;
                }
                let state = self.state.lock().expect("merge-sequence state poisoned");
                if state.completed == self.all_inlets.len()
                    && state.pending.is_empty()
                    && state.pending_output.is_empty()
                {
                    logic.complete(&self.outlet)?;
                } else if logic.is_available(&self.outlet)
                    && state.pending_output.is_empty()
                    && state.pending.len() + state.completed == self.all_inlets.len()
                {
                    return Err(StreamError::Failed(format!(
                        "expected sequence {}, but all input ports have pushed or are complete",
                        state.next_sequence
                    )));
                }
                Ok(())
            }
        }

        // Outlet handler.
        struct Out<T: 'static> {
            inlets: Vec<Inlet<T>>,
            outlet: Outlet<T>,
            state: Arc<Mutex<State<T>>>,
        }

        impl<T> OutHandler for Out<T>
        where
            T: Clone + Send + 'static,
        {
            // On downstream demand: push one ready element if there is one;
            // otherwise run the completion / stall check. Afterwards pull
            // every inlet that is neither already pulled nor closed.
            fn on_pull(
                &mut self,
                logic: &mut GraphStageLogic,
                _outlet: AnyOutlet,
            ) -> StreamResult<()> {
                let next = self
                    .state
                    .lock()
                    .expect("merge-sequence state poisoned")
                    .pending_output
                    .pop_front();
                if let Some(value) = next {
                    logic.push(&self.outlet, value)?;
                } else {
                    let state = self.state.lock().expect("merge-sequence state poisoned");
                    if state.completed == self.inlets.len() && state.pending.is_empty() {
                        logic.complete(&self.outlet)?;
                    } else if state.pending.len() + state.completed == self.inlets.len() {
                        return Err(StreamError::Failed(format!(
                            "expected sequence {}, but all input ports have pushed or are complete",
                            state.next_sequence
                        )));
                    }
                }
                for inlet in &self.inlets {
                    if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
                        logic.pull(inlet)?;
                    }
                }
                Ok(())
            }
        }

        let inlets = shape.inlets_vec();
        let outlet = shape.outlet();
        let state = Arc::new(Mutex::new(State {
            next_sequence: 0,
            pending: Vec::new(),
            completed: 0,
            pending_output: VecDeque::new(),
        }));
        let mut logic = GraphStageLogic::new(shape);
        // Every inlet handler gets its own index plus clones of the shared
        // ports, extractor and state.
        for (index, inlet) in inlets.iter().cloned().enumerate() {
            logic
                .set_handler(
                    &inlet.clone(),
                    Box::new(In {
                        inlet_id: inlet.id(),
                        inlet_index: index,
                        inlet: inlet.clone(),
                        all_inlets: inlets.clone(),
                        outlet: outlet.clone(),
                        extract_sequence: Arc::clone(&self.extract_sequence),
                        state: Arc::clone(&state),
                    }),
                )
                .unwrap();
        }
        logic
            .set_out_handler(
                &outlet.clone(),
                Box::new(Out {
                    inlets,
                    outlet: outlet.clone(),
                    state,
                }),
            )
            .unwrap();
        logic
    }
}
1201
/// Fan-in graph stage named `"MergeLatest"`.
///
/// The stage has `inputs` inlets of `T` and a single outlet of `Vec<T>`.
#[derive(Clone, Debug)]
pub struct MergeLatest<T: 'static> {
    /// Number of inlets allocated for the stage.
    inputs: usize,
    /// Eager-complete flag forwarded to the stage spec.
    eager_complete: bool,
    /// Ties the stage to its element type.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeLatest<T> {
    /// Creates a merge-latest stage with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize, eager_complete: bool) -> Self {
        assert!(
            inputs != 0,
            "merge-latest input count must be greater than zero"
        );
        Self { inputs, eager_complete, _marker: PhantomData }
    }
}
1223
1224impl<T> GraphStage for MergeLatest<T>
1225where
1226 T: Clone + Send + 'static,
1227{
1228 type Shape = FanInShape<T, Vec<T>>;
1229
    /// Returns the fixed name of this stage, `"MergeLatest"`.
    fn name(&self) -> &str {
        "MergeLatest"
    }
1233
1234 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1235 let inlets = (0..self.inputs)
1236 .map(|index| allocator.inlet(format!("MergeLatest.in{index}")))
1237 .collect();
1238 FanInShape::new(inlets, allocator.outlet("MergeLatest.out"))
1239 }
1240
1241 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1242 let build_snapshot = Arc::new(move |values: &[&DatumValue]| -> DatumValue {
1243 let snapshot: Vec<T> = values
1244 .iter()
1245 .map(|dv| {
1246 dv.as_any_ref()
1247 .downcast_ref::<T>()
1248 .cloned()
1249 .expect("merge-latest snapshot: wrong element type")
1250 })
1251 .collect();
1252 datum(snapshot)
1253 });
1254 #[allow(clippy::type_complexity)]
1257 let typed_snapshot_fn: Arc<dyn Fn(&[Option<T>]) -> Vec<T> + Send + Sync> =
1258 Arc::new(move |slots: &[Option<T>]| {
1259 slots
1260 .iter()
1261 .map(|s| {
1262 s.clone()
1263 .expect("merge-latest typed snapshot: slot is None")
1264 })
1265 .collect()
1266 });
1267 let typed_snapshot: Arc<StageTypedSnapshotFn> = Arc::new(typed_snapshot_fn);
1268 StageSpec::merge_latest(
1269 Arc::from(self.name()),
1270 shape.inlets(),
1271 shape.outlets(),
1272 self.inputs,
1273 self.eager_complete,
1274 build_snapshot,
1275 typed_snapshot,
1276 )
1277 }
1278
1279 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1280 struct State<T> {
1281 latest: Vec<Option<T>>,
1282 seen: usize,
1283 completed: usize,
1284 pending: VecDeque<Vec<T>>,
1285 eager_complete: bool,
1286 }
1287
1288 struct In<T: 'static> {
1289 inlet_id: PortId,
1290 inlet_index: usize,
1291 inlet: Inlet<T>,
1292 all_inlets: Vec<Inlet<T>>,
1293 outlet: Outlet<Vec<T>>,
1294 state: Arc<Mutex<State<T>>>,
1295 }
1296
1297 impl<T> InHandler for In<T>
1298 where
1299 T: Clone + Send + 'static,
1300 {
1301 fn on_push(
1302 &mut self,
1303 logic: &mut GraphStageLogic,
1304 _inlet: AnyInlet,
1305 ) -> StreamResult<()> {
1306 let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1307 downcast_datum(value, "grab", || {
1308 format!("inlet#{}", self.inlet_id.as_usize())
1309 })
1310 })?;
1311 {
1312 let mut state = self.state.lock().expect("merge-latest state poisoned");
1313 if state.latest[self.inlet_index].is_none() {
1314 state.seen += 1;
1315 }
1316 state.latest[self.inlet_index] = Some(elem);
1317 if state.seen == state.latest.len() {
1318 let snapshot = state
1319 .latest
1320 .iter()
1321 .map(|item| item.clone().expect("merge-latest seen"))
1322 .collect();
1323 state.pending.push_back(snapshot);
1324 }
1325 }
1326 let next = if logic.is_available(&self.outlet) {
1327 self.state
1328 .lock()
1329 .expect("merge-latest state poisoned")
1330 .pending
1331 .pop_front()
1332 } else {
1333 None
1334 };
1335 if let Some(value) = next {
1336 logic.push(&self.outlet, value)?;
1337 }
1338 if !logic.has_been_pulled(&self.inlet) {
1339 logic.pull(&self.inlet)?;
1340 }
1341 Ok(())
1342 }
1343
1344 fn on_upstream_finish(
1345 &mut self,
1346 logic: &mut GraphStageLogic,
1347 _inlet: AnyInlet,
1348 ) -> StreamResult<()> {
1349 let state = {
1350 let mut state = self.state.lock().expect("merge-latest state poisoned");
1351 state.completed += 1;
1352 (
1353 state.completed == self.all_inlets.len(),
1354 state.eager_complete,
1355 state.pending.is_empty(),
1356 )
1357 };
1358 if state.0 || (state.1 && state.2) {
1359 logic.complete(&self.outlet)?;
1360 }
1361 Ok(())
1362 }
1363 }
1364
1365 struct Out<T: 'static> {
1366 inlets: Vec<Inlet<T>>,
1367 outlet: Outlet<Vec<T>>,
1368 state: Arc<Mutex<State<T>>>,
1369 }
1370
1371 impl<T> OutHandler for Out<T>
1372 where
1373 T: Clone + Send + 'static,
1374 {
1375 fn on_pull(
1376 &mut self,
1377 logic: &mut GraphStageLogic,
1378 _outlet: AnyOutlet,
1379 ) -> StreamResult<()> {
1380 let next = self
1381 .state
1382 .lock()
1383 .expect("merge-latest state poisoned")
1384 .pending
1385 .pop_front();
1386 if let Some(value) = next {
1387 logic.push(&self.outlet, value)?;
1388 } else {
1389 let state = self.state.lock().expect("merge-latest state poisoned");
1390 if state.completed == self.inlets.len()
1391 || (state.eager_complete && state.completed > 0)
1392 {
1393 logic.complete(&self.outlet)?;
1394 }
1395 }
1396 for inlet in &self.inlets {
1397 if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1398 logic.pull(inlet)?;
1399 }
1400 }
1401 Ok(())
1402 }
1403 }
1404
1405 let inlets = shape.inlets_vec();
1406 let outlet = shape.outlet();
1407 let state = Arc::new(Mutex::new(State {
1408 latest: vec![None; inlets.len()],
1409 seen: 0,
1410 completed: 0,
1411 pending: VecDeque::new(),
1412 eager_complete: self.eager_complete,
1413 }));
1414 let mut logic = GraphStageLogic::new(shape);
1415 for (index, inlet) in inlets.iter().cloned().enumerate() {
1416 logic
1417 .set_handler(
1418 &inlet.clone(),
1419 Box::new(In {
1420 inlet_id: inlet.id(),
1421 inlet_index: index,
1422 inlet: inlet.clone(),
1423 all_inlets: inlets.clone(),
1424 outlet: outlet.clone(),
1425 state: Arc::clone(&state),
1426 }),
1427 )
1428 .unwrap();
1429 }
1430 logic
1431 .set_out_handler(
1432 &outlet.clone(),
1433 Box::new(Out {
1434 inlets,
1435 outlet: outlet.clone(),
1436 state,
1437 }),
1438 )
1439 .unwrap();
1440 logic
1441 }
1442}
1443
/// Fan-out stage description that routes each element to one of `outputs` outlets.
#[derive(Clone)]
pub struct Partition<T>
where
    T: 'static,
{
    /// Number of outlets the stage allocates.
    outputs: usize,
    /// Maps an element to the index of the outlet that should receive it.
    partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
    /// When `true`, the whole stage completes as soon as any outlet is cancelled.
    eager_cancel: bool,
    /// Records the element type without storing a value of it.
    _marker: PhantomData<fn() -> T>,
}
1451
1452impl<T: 'static> fmt::Debug for Partition<T> {
1453 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1454 f.debug_struct("Partition")
1455 .field("outputs", &self.outputs)
1456 .field("eager_cancel", &self.eager_cancel)
1457 .finish_non_exhaustive()
1458 }
1459}
1460
1461impl<T: 'static> Partition<T> {
1462 #[must_use]
1463 pub fn new<F>(outputs: usize, partitioner: F) -> Self
1464 where
1465 F: Fn(&T) -> usize + Send + Sync + 'static,
1466 {
1467 Self::new_with_eager_cancel(outputs, partitioner, false)
1468 }
1469
1470 #[must_use]
1471 pub fn new_with_eager_cancel<F>(outputs: usize, partitioner: F, eager_cancel: bool) -> Self
1472 where
1473 F: Fn(&T) -> usize + Send + Sync + 'static,
1474 {
1475 assert!(
1476 outputs > 0,
1477 "partition output count must be greater than zero"
1478 );
1479 Self {
1480 outputs,
1481 partitioner: Arc::new(partitioner),
1482 eager_cancel,
1483 _marker: PhantomData,
1484 }
1485 }
1486}
1487
1488impl<T> GraphStage for Partition<T>
1489where
1490 T: Clone + Send + 'static,
1491{
1492 type Shape = FanOutShape<T, T>;
1493
1494 fn name(&self) -> &str {
1495 "Partition"
1496 }
1497
1498 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1499 let inlet = allocator.inlet("Partition.in");
1500 let outlets = (0..self.outputs)
1501 .map(|index| allocator.outlet(format!("Partition.out{index}")))
1502 .collect();
1503 FanOutShape::new(inlet, outlets)
1504 }
1505
1506 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1507 let partitioner_clone = Arc::clone(&self.partitioner);
1508 let partitioner = Arc::new(move |dv: &DatumValue| -> usize {
1509 let t: &T = dv
1510 .as_any_ref()
1511 .downcast_ref::<T>()
1512 .expect("partition: wrong element type");
1513 partitioner_clone(t)
1514 });
1515 StageSpec::partition(
1516 Arc::from(self.name()),
1517 shape.inlets(),
1518 shape.outlets(),
1519 self.outputs,
1520 partitioner,
1521 self.eager_cancel,
1522 )
1523 }
1524
1525 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1526 struct State<T> {
1527 pending: Option<(usize, T)>,
1528 upstream_closed: bool,
1529 live_outlets: usize,
1530 cancelled: Vec<bool>,
1531 eager_cancel: bool,
1532 }
1533
1534 fn any_live_demand<T>(
1535 logic: &GraphStageLogic,
1536 outlets: &[Outlet<T>],
1537 cancelled: &[bool],
1538 ) -> bool
1539 where
1540 T: Clone + Send + 'static,
1541 {
1542 outlets
1543 .iter()
1544 .enumerate()
1545 .any(|(index, outlet)| !cancelled[index] && logic.is_available(outlet))
1546 }
1547
1548 struct In<T: 'static> {
1549 inlet_id: PortId,
1550 inlet: Inlet<T>,
1551 outlets: Vec<Outlet<T>>,
1552 partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
1553 state: Arc<Mutex<State<T>>>,
1554 }
1555
1556 impl<T> InHandler for In<T>
1557 where
1558 T: Clone + Send + 'static,
1559 {
1560 fn on_push(
1561 &mut self,
1562 logic: &mut GraphStageLogic,
1563 _inlet: AnyInlet,
1564 ) -> StreamResult<()> {
1565 let item: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1566 downcast_datum(value, "grab", || {
1567 format!("inlet#{}", self.inlet_id.as_usize())
1568 })
1569 })?;
1570 let idx = (self.partitioner)(&item);
1571 if idx >= self.outlets.len() {
1572 return Err(StreamError::Failed(format!(
1573 "partitioner returned out-of-bounds index {idx} for {} outputs",
1574 self.outlets.len()
1575 )));
1576 }
1577 let mut pull_again = false;
1578 {
1579 let mut state = self.state.lock().expect("partition state poisoned");
1580 if state.cancelled[idx] {
1581 pull_again = !state.upstream_closed
1582 && any_live_demand(logic, &self.outlets, &state.cancelled);
1583 } else if logic.is_available(&self.outlets[idx]) {
1584 logic.push(&self.outlets[idx], item)?;
1585 pull_again = !state.upstream_closed
1586 && any_live_demand(logic, &self.outlets, &state.cancelled);
1587 } else {
1588 state.pending = Some((idx, item));
1589 }
1590 }
1591 if pull_again && !logic.has_been_pulled(&self.inlet) {
1592 logic.pull(&self.inlet)?;
1593 }
1594 Ok(())
1595 }
1596
1597 fn on_upstream_finish(
1598 &mut self,
1599 logic: &mut GraphStageLogic,
1600 _inlet: AnyInlet,
1601 ) -> StreamResult<()> {
1602 let complete_now = {
1603 let mut state = self.state.lock().expect("partition state poisoned");
1604 state.upstream_closed = true;
1605 state.pending.is_none()
1606 };
1607 if complete_now {
1608 for outlet in &self.outlets {
1609 if !logic.is_closed(outlet) {
1610 logic.complete(outlet)?;
1611 }
1612 }
1613 }
1614 Ok(())
1615 }
1616 }
1617
1618 struct Out<T: 'static> {
1619 index: usize,
1620 inlet: Inlet<T>,
1621 outlets: Vec<Outlet<T>>,
1622 state: Arc<Mutex<State<T>>>,
1623 }
1624
1625 impl<T> OutHandler for Out<T>
1626 where
1627 T: Clone + Send + 'static,
1628 {
1629 fn on_pull(
1630 &mut self,
1631 logic: &mut GraphStageLogic,
1632 _outlet: AnyOutlet,
1633 ) -> StreamResult<()> {
1634 let mut complete_now = false;
1635 let pending = {
1636 let mut state = self.state.lock().expect("partition state poisoned");
1637 if let Some((idx, _)) = &state.pending
1638 && *idx == self.index
1639 {
1640 state.pending.take()
1641 } else {
1642 None
1643 }
1644 };
1645 if let Some((_, item)) = pending {
1646 logic.push(&self.outlets[self.index], item)?;
1647 let state = self.state.lock().expect("partition state poisoned");
1648 if state.upstream_closed {
1649 complete_now = true;
1650 } else if any_live_demand(logic, &self.outlets, &state.cancelled)
1651 && !logic.has_been_pulled(&self.inlet)
1652 {
1653 logic.pull(&self.inlet)?;
1654 }
1655 } else {
1656 let state = self.state.lock().expect("partition state poisoned");
1657 if state.upstream_closed {
1658 complete_now = true;
1659 } else if any_live_demand(logic, &self.outlets, &state.cancelled)
1660 && !logic.has_been_pulled(&self.inlet)
1661 {
1662 logic.pull(&self.inlet)?;
1663 }
1664 }
1665 if complete_now {
1666 for outlet in &self.outlets {
1667 if !logic.is_closed(outlet) {
1668 logic.complete(outlet)?;
1669 }
1670 }
1671 }
1672 Ok(())
1673 }
1674
1675 fn on_downstream_finish(
1676 &mut self,
1677 logic: &mut GraphStageLogic,
1678 _outlet: AnyOutlet,
1679 ) -> StreamResult<()> {
1680 let (cancel_stage, clear_pending) = {
1681 let mut state = self.state.lock().expect("partition state poisoned");
1682 if state.cancelled[self.index] {
1683 return Ok(());
1684 }
1685 state.cancelled[self.index] = true;
1686 state.live_outlets -= 1;
1687 let clear_pending = state
1688 .pending
1689 .as_ref()
1690 .is_some_and(|(idx, _)| *idx == self.index);
1691 let cancel_stage = state.eager_cancel || state.live_outlets == 0;
1692 if clear_pending {
1693 state.pending = None;
1694 }
1695 (cancel_stage, clear_pending)
1696 };
1697 if cancel_stage {
1698 logic.complete_stage()?;
1699 } else if clear_pending
1700 && !logic.has_been_pulled(&self.inlet)
1701 && !logic.is_closed(&self.inlet)
1702 {
1703 let state = self.state.lock().expect("partition state poisoned");
1704 if any_live_demand(logic, &self.outlets, &state.cancelled) {
1705 logic.pull(&self.inlet)?;
1706 }
1707 }
1708 Ok(())
1709 }
1710 }
1711
1712 let inlet = shape.inlet();
1713 let outlets = shape.outlets_vec();
1714 let state = Arc::new(Mutex::new(State {
1715 pending: None,
1716 upstream_closed: false,
1717 live_outlets: outlets.len(),
1718 cancelled: vec![false; outlets.len()],
1719 eager_cancel: self.eager_cancel,
1720 }));
1721 let mut logic = GraphStageLogic::new(shape);
1722 logic
1723 .set_handler(
1724 &inlet,
1725 Box::new(In {
1726 inlet_id: inlet.id(),
1727 inlet: inlet.clone(),
1728 outlets: outlets.clone(),
1729 partitioner: Arc::clone(&self.partitioner),
1730 state: Arc::clone(&state),
1731 }),
1732 )
1733 .unwrap();
1734 for (index, outlet) in outlets.iter().cloned().enumerate() {
1735 logic
1736 .set_out_handler(
1737 &outlet,
1738 Box::new(Out {
1739 index,
1740 inlet: inlet.clone(),
1741 outlets: outlets.clone(),
1742 state: Arc::clone(&state),
1743 }),
1744 )
1745 .unwrap();
1746 }
1747 logic
1748 }
1749}
1750
/// Stateless fan-out stage description that splits `(A, B)` pairs into an `A` and a `B`.
#[derive(Clone, Debug)]
pub struct Unzip<A, B>
where
    A: 'static,
    B: 'static,
{
    /// Records the pair's component types without storing any value.
    _marker: PhantomData<fn() -> (A, B)>,
}
1755
1756impl<A: 'static, B: 'static> Unzip<A, B> {
1757 #[must_use]
1758 pub fn new() -> Self {
1759 Self {
1760 _marker: PhantomData,
1761 }
1762 }
1763}
1764
1765impl<A: 'static, B: 'static> Default for Unzip<A, B> {
1766 fn default() -> Self {
1767 Self::new()
1768 }
1769}
1770
1771impl<A, B> GraphStage for Unzip<A, B>
1772where
1773 A: Clone + Send + 'static,
1774 B: Clone + Send + 'static,
1775{
1776 type Shape = FanOutShape2<(A, B), A, B>;
1777
1778 fn name(&self) -> &str {
1779 "Unzip"
1780 }
1781
1782 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1783 FanOutShape2::new(
1784 allocator.inlet("Unzip.in"),
1785 allocator.outlet("Unzip.out0"),
1786 allocator.outlet("Unzip.out1"),
1787 )
1788 }
1789
1790 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1791 let split = Arc::new(|dv: DatumValue| -> (DatumValue, DatumValue) {
1792 let pair: (A, B) =
1793 downcast_datum(dv, "unzip", || "Unzip.in").expect("unzip: wrong element type");
1794 (datum(pair.0), datum(pair.1))
1795 });
1796 #[allow(clippy::type_complexity)]
1801 let typed_split_fn: Arc<dyn Fn((A, B)) -> (A, B) + Send + Sync> =
1802 Arc::new(|pair: (A, B)| pair);
1803 let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
1804 StageSpec::unzip(
1805 Arc::from(self.name()),
1806 shape.inlets(),
1807 shape.outlets(),
1808 split,
1809 typed_split,
1810 )
1811 }
1812
1813 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1814 UnzipWith::new(|pair: (A, B)| pair).create_logic(shape)
1815 }
1816}
1817
/// Fan-out stage description that splits each `In` element into an `Out0` and an `Out1`
/// using a caller-supplied function.
#[derive(Clone)]
pub struct UnzipWith<In, Out0, Out1>
where
    In: 'static,
    Out0: 'static,
    Out1: 'static,
{
    /// Function that turns one input element into the pair of output elements.
    split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
    /// Records the input and output types without storing any value.
    _marker: PhantomData<fn(In) -> (Out0, Out1)>,
}
1823
1824impl<In: 'static, Out0: 'static, Out1: 'static> fmt::Debug for UnzipWith<In, Out0, Out1> {
1825 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1826 f.debug_struct("UnzipWith").finish_non_exhaustive()
1827 }
1828}
1829
1830impl<In: 'static, Out0: 'static, Out1: 'static> UnzipWith<In, Out0, Out1> {
1831 #[must_use]
1832 pub fn new<F>(split: F) -> Self
1833 where
1834 F: Fn(In) -> (Out0, Out1) + Send + Sync + 'static,
1835 {
1836 Self {
1837 split: Arc::new(split),
1838 _marker: PhantomData,
1839 }
1840 }
1841}
1842
1843impl<In, Out0, Out1> GraphStage for UnzipWith<In, Out0, Out1>
1844where
1845 In: Clone + Send + 'static,
1846 Out0: Clone + Send + 'static,
1847 Out1: Clone + Send + 'static,
1848{
1849 type Shape = FanOutShape2<In, Out0, Out1>;
1850
1851 fn name(&self) -> &str {
1852 "UnzipWith"
1853 }
1854
1855 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1856 FanOutShape2::new(
1857 allocator.inlet("UnzipWith.in"),
1858 allocator.outlet("UnzipWith.out0"),
1859 allocator.outlet("UnzipWith.out1"),
1860 )
1861 }
1862
1863 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1864 let split_fn = Arc::clone(&self.split);
1865 let split = Arc::new(move |dv: DatumValue| -> (DatumValue, DatumValue) {
1866 let value: In = downcast_datum(dv, "unzip_with", || "UnzipWith.in")
1867 .expect("unzip-with: wrong element type");
1868 let (out0, out1) = split_fn(value);
1869 (datum(out0), datum(out1))
1870 });
1871 let typed_split_fn: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::clone(&self.split);
1874 let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
1875 StageSpec::unzip(
1876 Arc::from(self.name()),
1877 shape.inlets(),
1878 shape.outlets(),
1879 split,
1880 typed_split,
1881 )
1882 }
1883
1884 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1885 struct State {
1886 left_open: bool,
1887 right_open: bool,
1888 upstream_closed: bool,
1889 }
1890
1891 struct InHandlerState<In: 'static, Out0: 'static, Out1: 'static> {
1892 inlet_id: PortId,
1893 inlet: Inlet<In>,
1894 out0: Outlet<Out0>,
1895 out1: Outlet<Out1>,
1896 split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
1897 state: Arc<Mutex<State>>,
1898 }
1899
1900 impl<In, Out0, Out1> InHandler for InHandlerState<In, Out0, Out1>
1901 where
1902 In: Clone + Send + 'static,
1903 Out0: Clone + Send + 'static,
1904 Out1: Clone + Send + 'static,
1905 {
1906 fn on_push(
1907 &mut self,
1908 logic: &mut GraphStageLogic,
1909 _inlet: AnyInlet,
1910 ) -> StreamResult<()> {
1911 let value: In = logic.grab_datum(self.inlet_id).and_then(|value| {
1912 downcast_datum(value, "grab", || {
1913 format!("inlet#{}", self.inlet_id.as_usize())
1914 })
1915 })?;
1916 let (left, right) = (self.split)(value);
1917 let state = self.state.lock().expect("unzip-with state poisoned");
1918 if state.left_open {
1919 logic.push(&self.out0, left)?;
1920 }
1921 if state.right_open {
1922 logic.push(&self.out1, right)?;
1923 }
1924 drop(state);
1925 let state = self.state.lock().expect("unzip-with state poisoned");
1926 let left_ready = !state.left_open || logic.is_available(&self.out0);
1927 let right_ready = !state.right_open || logic.is_available(&self.out1);
1928 if (state.left_open || state.right_open)
1929 && left_ready
1930 && right_ready
1931 && !logic.has_been_pulled(&self.inlet)
1932 {
1933 logic.pull(&self.inlet)?;
1934 }
1935 Ok(())
1936 }
1937
1938 fn on_upstream_finish(
1939 &mut self,
1940 logic: &mut GraphStageLogic,
1941 _inlet: AnyInlet,
1942 ) -> StreamResult<()> {
1943 self.state
1944 .lock()
1945 .expect("unzip-with state poisoned")
1946 .upstream_closed = true;
1947 if !logic.is_closed(&self.out0) {
1948 logic.complete(&self.out0)?;
1949 }
1950 if !logic.is_closed(&self.out1) {
1951 logic.complete(&self.out1)?;
1952 }
1953 Ok(())
1954 }
1955 }
1956
1957 struct Out<In: 'static, Out0: 'static, Out1: 'static> {
1958 is_left: bool,
1959 inlet: Inlet<In>,
1960 out0: Outlet<Out0>,
1961 out1: Outlet<Out1>,
1962 state: Arc<Mutex<State>>,
1963 }
1964
1965 impl<In, Out0, Out1> OutHandler for Out<In, Out0, Out1>
1966 where
1967 In: Clone + Send + 'static,
1968 Out0: Clone + Send + 'static,
1969 Out1: Clone + Send + 'static,
1970 {
1971 fn on_pull(
1972 &mut self,
1973 logic: &mut GraphStageLogic,
1974 _outlet: AnyOutlet,
1975 ) -> StreamResult<()> {
1976 let state = self.state.lock().expect("unzip-with state poisoned");
1977 let left_ready = !state.left_open || logic.is_available(&self.out0);
1978 let right_ready = !state.right_open || logic.is_available(&self.out1);
1979 if state.upstream_closed {
1980 drop(state);
1981 if !logic.is_closed(&self.out0) {
1982 logic.complete(&self.out0)?;
1983 }
1984 if !logic.is_closed(&self.out1) {
1985 logic.complete(&self.out1)?;
1986 }
1987 } else if (state.left_open || state.right_open)
1988 && left_ready
1989 && right_ready
1990 && !logic.has_been_pulled(&self.inlet)
1991 {
1992 drop(state);
1993 logic.pull(&self.inlet)?;
1994 }
1995 Ok(())
1996 }
1997
1998 fn on_downstream_finish(
1999 &mut self,
2000 logic: &mut GraphStageLogic,
2001 _outlet: AnyOutlet,
2002 ) -> StreamResult<()> {
2003 let mut state = self.state.lock().expect("unzip-with state poisoned");
2004 if self.is_left {
2005 state.left_open = false;
2006 } else {
2007 state.right_open = false;
2008 }
2009 if !state.left_open && !state.right_open {
2010 logic.complete_stage()?;
2011 return Ok(());
2012 }
2013 let left_ready = !state.left_open || logic.is_available(&self.out0);
2014 let right_ready = !state.right_open || logic.is_available(&self.out1);
2015 if !state.upstream_closed
2016 && (state.left_open || state.right_open)
2017 && left_ready
2018 && right_ready
2019 && !logic.has_been_pulled(&self.inlet)
2020 {
2021 logic.pull(&self.inlet)?;
2022 }
2023 Ok(())
2024 }
2025 }
2026
2027 let inlet = shape.inlet();
2028 let out0 = shape.out0();
2029 let out1 = shape.out1();
2030 let state = Arc::new(Mutex::new(State {
2031 left_open: true,
2032 right_open: true,
2033 upstream_closed: false,
2034 }));
2035 let mut logic = GraphStageLogic::new(shape);
2036 logic
2037 .set_handler(
2038 &inlet,
2039 Box::new(InHandlerState {
2040 inlet_id: inlet.id(),
2041 inlet: inlet.clone(),
2042 out0: out0.clone(),
2043 out1: out1.clone(),
2044 split: Arc::clone(&self.split),
2045 state: Arc::clone(&state),
2046 }),
2047 )
2048 .unwrap();
2049 logic
2050 .set_out_handler(
2051 &out0,
2052 Box::new(Out {
2053 is_left: true,
2054 inlet: inlet.clone(),
2055 out0: out0.clone(),
2056 out1: out1.clone(),
2057 state: Arc::clone(&state),
2058 }),
2059 )
2060 .unwrap();
2061 logic
2062 .set_out_handler(
2063 &out1.clone(),
2064 Box::new(Out {
2065 is_left: false,
2066 inlet: inlet.clone(),
2067 out0: out0.clone(),
2068 out1: out1.clone(),
2069 state,
2070 }),
2071 )
2072 .unwrap();
2073 logic
2074 }
2075}
2076
/// Stateless stage description with one inlet and one outlet of the same element type `T`.
#[derive(Clone, Debug)]
pub struct AsyncBoundary<T>
where
    T: 'static,
{
    /// Records the element type without storing a value of it.
    _marker: PhantomData<fn() -> T>,
}
2081
2082impl<T: 'static> AsyncBoundary<T> {
2083 #[must_use]
2084 pub fn new() -> Self {
2085 Self {
2086 _marker: PhantomData,
2087 }
2088 }
2089}
2090
2091impl<T: 'static> Default for AsyncBoundary<T> {
2092 fn default() -> Self {
2093 Self::new()
2094 }
2095}
2096
2097impl<T> GraphStage for AsyncBoundary<T>
2098where
2099 T: Clone + Send + 'static,
2100{
2101 type Shape = FlowShape<T, T>;
2102
2103 fn name(&self) -> &str {
2104 "AsyncBoundary"
2105 }
2106
2107 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
2108 let first_id = next_port_id_block(2);
2109 FlowShape::new(
2110 Inlet::with_arc_name(first_id, async_boundary_inlet_name()),
2111 Outlet::with_arc_name(first_id.offset(1), async_boundary_outlet_name()),
2112 )
2113 }
2114
2115 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2116 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
2117 }
2118
2119 fn stage_spec_with_ports(
2120 &self,
2121 _shape: &Self::Shape,
2122 inlets: Vec<AnyInlet>,
2123 outlets: Vec<AnyOutlet>,
2124 ) -> StageSpec {
2125 StageSpec::async_boundary(async_boundary_stage_name(), inlets, outlets)
2126 }
2127}