1use super::*;
2
/// Stateless flow stage description parameterised only by its element type `T`.
///
/// The struct never stores a `T`; `PhantomData<fn() -> T>` merely records the type.
/// `Clone` and `Debug` are therefore implemented by hand so that they are available
/// for every `T`, instead of only for `T: Clone` / `T: Debug` as a derive would demand.
pub struct Identity<T: 'static> {
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for Identity<T> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for Identity<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("Identity")
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> Identity<T> {
    /// Creates a new `Identity` stage description.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for Identity<T> {
    /// Equivalent to [`Identity::new`].
    fn default() -> Self {
        Self::new()
    }
}
22
23impl<T> GraphStage for Identity<T>
24where
25 T: Clone + Send + 'static,
26{
27 type Shape = FlowShape<T, T>;
28
29 fn name(&self) -> &str {
30 "Identity"
31 }
32
33 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
34 let first_id = next_port_id_block(2);
35 FlowShape::new(
36 Inlet::with_arc_name(first_id, identity_inlet_name()),
37 Outlet::with_arc_name(first_id.offset(1), identity_outlet_name()),
38 )
39 }
40
41 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
42 StageSpec::identity(shape.inlets(), shape.outlets())
43 }
44
45 fn stage_spec_with_ports(
46 &self,
47 _shape: &Self::Shape,
48 inlets: Vec<AnyInlet>,
49 outlets: Vec<AnyOutlet>,
50 ) -> StageSpec {
51 StageSpec::identity(inlets, outlets)
52 }
53}
54
/// Flow stage description that applies a shared function `In -> Out` to each element.
///
/// The function lives behind an `Arc`, so cloning the stage only bumps a reference
/// count. `Clone` is written by hand because a derive would require `In: Clone` and
/// `Out: Clone`, although no value of either type is stored.
pub struct MapStage<In: 'static, Out: 'static> {
    /// The mapping function, shared between clones.
    f: Arc<dyn Fn(In) -> Out + Send + Sync>,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn(In) -> Out>,
}

impl<In: 'static, Out: 'static> Clone for MapStage<In, Out> {
    fn clone(&self) -> Self {
        Self {
            f: Arc::clone(&self.f),
            _marker: PhantomData,
        }
    }
}

impl<In: 'static, Out: 'static> fmt::Debug for MapStage<In, Out> {
    /// Prints a fixed `name` field only; the closure itself cannot be formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MapStage")
            .field("name", &"Map")
            .finish_non_exhaustive()
    }
}

impl<In: 'static, Out: 'static> MapStage<In, Out> {
    /// Creates a map stage from `f`, which is wrapped in an `Arc` for sharing.
    #[must_use]
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(In) -> Out + Send + Sync + 'static,
    {
        Self {
            f: Arc::new(f),
            _marker: PhantomData,
        }
    }
}
81
82impl<In, Out> GraphStage for MapStage<In, Out>
83where
84 In: Clone + Send + 'static,
85 Out: Clone + Send + 'static,
86{
87 type Shape = FlowShape<In, Out>;
88
89 fn name(&self) -> &str {
90 "Map"
91 }
92
93 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
94 let first_id = next_port_id_block(2);
95 FlowShape::new(
96 Inlet::with_arc_name(first_id, map_inlet_name()),
97 Outlet::with_arc_name(first_id.offset(1), map_outlet_name()),
98 )
99 }
100
101 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
102 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
103 }
104
105 fn stage_spec_with_ports(
106 &self,
107 _shape: &Self::Shape,
108 inlets: Vec<AnyInlet>,
109 outlets: Vec<AnyOutlet>,
110 ) -> StageSpec {
111 let f = Arc::clone(&self.f);
112 let typed = Arc::new(Arc::clone(&self.f)) as Arc<StageTypedMapFn>;
113 let mapper = Arc::new(move |value: DatumValue| {
114 let value: In = downcast_datum(value, "map", || "Map.in")?;
115 Ok(datum(f(value)))
116 });
117 StageSpec::map(map_stage_name(), inlets, outlets, mapper, typed)
118 }
119}
120
121#[derive(Clone, Debug)]
122pub struct Buffer<T: 'static> {
123 capacity: usize,
124 strategy: OverflowStrategy,
125 _marker: PhantomData<fn() -> T>,
126}
127
128impl<T: 'static> Buffer<T> {
129 #[must_use]
130 pub fn new(capacity: usize, strategy: OverflowStrategy) -> Self {
131 assert!(capacity > 0, "buffer capacity must be greater than zero");
132 Self {
133 capacity,
134 strategy,
135 _marker: PhantomData,
136 }
137 }
138}
139
140impl<T> GraphStage for Buffer<T>
141where
142 T: Clone + Send + 'static,
143{
144 type Shape = FlowShape<T, T>;
145
146 fn name(&self) -> &str {
147 "Buffer"
148 }
149
150 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
151 let first_id = next_port_id_block(2);
152 FlowShape::new(
153 Inlet::with_id(first_id, "Buffer.in"),
154 Outlet::with_id(first_id.offset(1), "Buffer.out"),
155 )
156 }
157
158 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
159 StageSpec::opaque(self.name(), shape.inlets(), shape.outlets())
163 .with_typed_cyclic(TypedCyclicOp::BufferPassthrough)
164 }
165
166 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
167 struct State<T> {
168 queue: VecDeque<T>,
169 upstream_closed: bool,
170 }
171
172 fn lock_state<T>(
173 state: &Arc<Mutex<State<T>>>,
174 ) -> StreamResult<std::sync::MutexGuard<'_, State<T>>> {
175 state
176 .lock()
177 .map_err(|_| StreamError::Failed("graph buffer state poisoned".into()))
178 }
179
180 fn enqueue<T>(
181 state: &mut State<T>,
182 value: T,
183 capacity: usize,
184 strategy: OverflowStrategy,
185 ) -> StreamResult<()> {
186 if state.queue.len() < capacity {
187 state.queue.push_back(value);
188 return Ok(());
189 }
190
191 match strategy {
192 OverflowStrategy::Backpressure | OverflowStrategy::Fail => Err(
193 StreamError::Failed(format!("Buffer overflow (max capacity was: {capacity})!")),
194 ),
195 OverflowStrategy::DropHead => {
196 let _ = state.queue.pop_front();
197 state.queue.push_back(value);
198 Ok(())
199 }
200 OverflowStrategy::DropTail => {
201 let _ = state.queue.pop_back();
202 state.queue.push_back(value);
203 Ok(())
204 }
205 OverflowStrategy::DropBuffer => {
206 state.queue.clear();
207 state.queue.push_back(value);
208 Ok(())
209 }
210 OverflowStrategy::DropNew => Ok(()),
211 }
212 }
213
214 fn drain_one<T>(
215 logic: &mut GraphStageLogic,
216 outlet: &Outlet<T>,
217 state: &Arc<Mutex<State<T>>>,
218 ) -> StreamResult<()>
219 where
220 T: Clone + Send + 'static,
221 {
222 if !logic.is_available(outlet) {
223 return Ok(());
224 }
225 let next = lock_state(state)?.queue.pop_front();
226 if let Some(value) = next {
227 logic.push(outlet, value)?;
228 }
229 Ok(())
230 }
231
232 fn maybe_pull<T>(
233 logic: &mut GraphStageLogic,
234 inlet: &Inlet<T>,
235 state: &Arc<Mutex<State<T>>>,
236 capacity: usize,
237 ) -> StreamResult<()>
238 where
239 T: 'static,
240 {
241 let can_pull = {
242 let state = lock_state(state)?;
243 !state.upstream_closed && state.queue.len() < capacity
244 };
245 if can_pull && !logic.has_been_pulled(inlet) {
246 logic.pull(inlet)?;
247 }
248 Ok(())
249 }
250
251 fn maybe_complete<T>(
252 logic: &mut GraphStageLogic,
253 outlet: &Outlet<T>,
254 state: &Arc<Mutex<State<T>>>,
255 ) -> StreamResult<()>
256 where
257 T: 'static,
258 {
259 let done = {
260 let state = lock_state(state)?;
261 state.upstream_closed && state.queue.is_empty()
262 };
263 if done && !logic.is_closed(outlet) {
264 logic.complete(outlet)?;
265 }
266 Ok(())
267 }
268
269 struct In<T: 'static> {
270 inlet: Inlet<T>,
271 outlet: Outlet<T>,
272 state: Arc<Mutex<State<T>>>,
273 capacity: usize,
274 strategy: OverflowStrategy,
275 }
276
277 impl<T> InHandler for In<T>
278 where
279 T: Clone + Send + 'static,
280 {
281 fn on_push(
282 &mut self,
283 logic: &mut GraphStageLogic,
284 _inlet: AnyInlet,
285 ) -> StreamResult<()> {
286 let value = logic.grab(&self.inlet)?;
287 {
288 let mut state = lock_state(&self.state)?;
289 enqueue(&mut state, value, self.capacity, self.strategy)?;
290 }
291 drain_one(logic, &self.outlet, &self.state)?;
292 maybe_pull(logic, &self.inlet, &self.state, self.capacity)?;
293 maybe_complete(logic, &self.outlet, &self.state)
294 }
295
296 fn on_upstream_finish(
297 &mut self,
298 logic: &mut GraphStageLogic,
299 _inlet: AnyInlet,
300 ) -> StreamResult<()> {
301 lock_state(&self.state)?.upstream_closed = true;
302 drain_one(logic, &self.outlet, &self.state)?;
303 maybe_complete(logic, &self.outlet, &self.state)
304 }
305 }
306
307 struct Out<T: 'static> {
308 inlet: Inlet<T>,
309 outlet: Outlet<T>,
310 state: Arc<Mutex<State<T>>>,
311 capacity: usize,
312 }
313
314 impl<T> OutHandler for Out<T>
315 where
316 T: Clone + Send + 'static,
317 {
318 fn on_pull(
319 &mut self,
320 logic: &mut GraphStageLogic,
321 _outlet: AnyOutlet,
322 ) -> StreamResult<()> {
323 drain_one(logic, &self.outlet, &self.state)?;
324 maybe_pull(logic, &self.inlet, &self.state, self.capacity)?;
325 maybe_complete(logic, &self.outlet, &self.state)
326 }
327
328 fn on_downstream_finish(
329 &mut self,
330 logic: &mut GraphStageLogic,
331 _outlet: AnyOutlet,
332 ) -> StreamResult<()> {
333 logic.complete_stage()
334 }
335 }
336
337 let state = Arc::new(Mutex::new(State {
338 queue: VecDeque::new(),
339 upstream_closed: false,
340 }));
341 let mut logic = GraphStageLogic::new(shape);
342 logic.pull(&shape.inlet()).unwrap();
343 logic
344 .set_handler(
345 &shape.inlet(),
346 Box::new(In {
347 inlet: shape.inlet(),
348 outlet: shape.outlet(),
349 state: Arc::clone(&state),
350 capacity: self.capacity,
351 strategy: self.strategy,
352 }),
353 )
354 .unwrap();
355 logic
356 .set_out_handler(
357 &shape.outlet(),
358 Box::new(Out {
359 inlet: shape.inlet(),
360 outlet: shape.outlet(),
361 state,
362 capacity: self.capacity,
363 }),
364 )
365 .unwrap();
366 logic
367 }
368}
369
/// Flow stage description holding a shared predicate over `&T`.
///
/// The predicate lives behind an `Arc`, so cloning the stage only bumps a reference
/// count. `Clone` is written by hand because a derive would require `T: Clone`,
/// although no `T` is stored.
pub struct TakeWhile<T: 'static> {
    /// The predicate, shared between clones.
    predicate: Arc<dyn Fn(&T) -> bool + Send + Sync>,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for TakeWhile<T> {
    fn clone(&self) -> Self {
        Self {
            predicate: Arc::clone(&self.predicate),
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for TakeWhile<T> {
    /// Prints a fixed `name` field only; the closure itself cannot be formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TakeWhile")
            .field("name", &"TakeWhile")
            .finish_non_exhaustive()
    }
}

impl<T: 'static> TakeWhile<T> {
    /// Creates the stage from `predicate`, which is wrapped in an `Arc` for sharing.
    #[must_use]
    pub fn new<F>(predicate: F) -> Self
    where
        F: Fn(&T) -> bool + Send + Sync + 'static,
    {
        Self {
            predicate: Arc::new(predicate),
            _marker: PhantomData,
        }
    }
}
396
397impl<T> GraphStage for TakeWhile<T>
398where
399 T: Clone + Send + 'static,
400{
401 type Shape = FlowShape<T, T>;
402
403 fn name(&self) -> &str {
404 "TakeWhile"
405 }
406
407 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
408 let first_id = next_port_id_block(2);
409 FlowShape::new(
410 Inlet::with_id(first_id, "TakeWhile.in"),
411 Outlet::with_id(first_id.offset(1), "TakeWhile.out"),
412 )
413 }
414
415 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
416 let predicate: Arc<dyn Fn(&T) -> bool + Send + Sync> = Arc::clone(&self.predicate);
421 StageSpec::opaque(self.name(), shape.inlets(), shape.outlets()).with_typed_cyclic(
422 TypedCyclicOp::TakeWhile(Arc::new(predicate) as Arc<dyn Any + Send + Sync>),
423 )
424 }
425
426 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
427 struct In<T: 'static> {
428 inlet: Inlet<T>,
429 outlet: Outlet<T>,
430 predicate: Arc<dyn Fn(&T) -> bool + Send + Sync>,
431 }
432
433 impl<T> InHandler for In<T>
434 where
435 T: Clone + Send + 'static,
436 {
437 fn on_push(
438 &mut self,
439 logic: &mut GraphStageLogic,
440 _inlet: AnyInlet,
441 ) -> StreamResult<()> {
442 let value = logic.grab(&self.inlet)?;
443 if (self.predicate)(&value) {
444 logic.emit(&self.outlet, value)
445 } else if logic.is_closed(&self.outlet) {
446 Ok(())
447 } else {
448 logic.complete(&self.outlet)
449 }
450 }
451
452 fn on_upstream_finish(
453 &mut self,
454 logic: &mut GraphStageLogic,
455 _inlet: AnyInlet,
456 ) -> StreamResult<()> {
457 if logic.is_closed(&self.outlet) {
458 Ok(())
459 } else {
460 logic.complete(&self.outlet)
461 }
462 }
463 }
464
465 struct Out<T: 'static> {
466 inlet: Inlet<T>,
467 }
468
469 impl<T> OutHandler for Out<T>
470 where
471 T: Clone + Send + 'static,
472 {
473 fn on_pull(
474 &mut self,
475 logic: &mut GraphStageLogic,
476 _outlet: AnyOutlet,
477 ) -> StreamResult<()> {
478 if !logic.has_been_pulled(&self.inlet) {
479 logic.pull(&self.inlet)?;
480 }
481 Ok(())
482 }
483
484 fn on_downstream_finish(
485 &mut self,
486 logic: &mut GraphStageLogic,
487 _outlet: AnyOutlet,
488 ) -> StreamResult<()> {
489 logic.complete_stage()
490 }
491 }
492
493 let mut logic = GraphStageLogic::new(shape);
494 logic
495 .set_handler(
496 &shape.inlet(),
497 Box::new(In {
498 inlet: shape.inlet(),
499 outlet: shape.outlet(),
500 predicate: Arc::clone(&self.predicate),
501 }),
502 )
503 .unwrap();
504 logic
505 .set_out_handler(
506 &shape.outlet(),
507 Box::new(Out {
508 inlet: shape.inlet(),
509 }),
510 )
511 .unwrap();
512 logic
513 }
514}
515
/// Fan-out stage description with one inlet and a fixed number of outlets.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct stores only the outlet count, never a `T`.
pub struct Broadcast<T: 'static> {
    /// Number of outlets; always greater than zero (see [`Broadcast::new`]).
    outputs: usize,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for Broadcast<T> {
    fn clone(&self) -> Self {
        Self {
            outputs: self.outputs,
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for Broadcast<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("Broadcast")
            .field("outputs", &self.outputs)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> Broadcast<T> {
    /// Creates a broadcast description with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        assert!(
            outputs > 0,
            "broadcast output count must be greater than zero"
        );
        Self {
            outputs,
            _marker: PhantomData,
        }
    }
}
535
536impl<T> GraphStage for Broadcast<T>
537where
538 T: Clone + Send + 'static,
539{
540 type Shape = FanOutShape<T, T>;
541
542 fn name(&self) -> &str {
543 "Broadcast"
544 }
545
546 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
547 let inlet = allocator.inlet_arc(broadcast_inlet_name());
548 let outlets = (0..self.outputs)
549 .map(|index| allocator.outlet(format!("Broadcast.out{index}")))
550 .collect();
551 FanOutShape::new(inlet, outlets)
552 }
553
554 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
555 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
556 }
557
558 fn stage_spec_with_ports(
559 &self,
560 _shape: &Self::Shape,
561 inlets: Vec<AnyInlet>,
562 outlets: Vec<AnyOutlet>,
563 ) -> StageSpec {
564 StageSpec::broadcast(broadcast_stage_name(), inlets, outlets)
565 }
566}
567
/// Fan-out stage description with one inlet and a fixed number of outlets.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct stores only the outlet count, never a `T`.
pub struct Balance<T: 'static> {
    /// Number of outlets; always greater than zero (see [`Balance::new`]).
    outputs: usize,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for Balance<T> {
    fn clone(&self) -> Self {
        Self {
            outputs: self.outputs,
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for Balance<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("Balance")
            .field("outputs", &self.outputs)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> Balance<T> {
    /// Creates a balance description with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        assert!(
            outputs > 0,
            "balance output count must be greater than zero"
        );
        Self {
            outputs,
            _marker: PhantomData,
        }
    }
}
587
588impl<T> GraphStage for Balance<T>
589where
590 T: Clone + Send + 'static,
591{
592 type Shape = FanOutShape<T, T>;
593
594 fn name(&self) -> &str {
595 "Balance"
596 }
597
598 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
599 let inlet = allocator.inlet_arc(balance_inlet_name());
600 let outlets = (0..self.outputs)
601 .map(|index| allocator.outlet(format!("Balance.out{index}")))
602 .collect();
603 FanOutShape::new(inlet, outlets)
604 }
605
606 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
607 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
608 }
609
610 fn stage_spec_with_ports(
611 &self,
612 _shape: &Self::Shape,
613 inlets: Vec<AnyInlet>,
614 outlets: Vec<AnyOutlet>,
615 ) -> StageSpec {
616 StageSpec::balance(balance_stage_name(), inlets, outlets)
617 }
618}
619
/// Fan-in stage description with a fixed number of inlets and one outlet.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct stores only the inlet count, never a `T`.
pub struct Merge<T: 'static> {
    /// Number of inlets; always greater than zero (see [`Merge::new`]).
    inputs: usize,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for Merge<T> {
    fn clone(&self) -> Self {
        Self {
            inputs: self.inputs,
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for Merge<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("Merge")
            .field("inputs", &self.inputs)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> Merge<T> {
    /// Creates a merge description with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs > 0, "merge input count must be greater than zero");
        Self {
            inputs,
            _marker: PhantomData,
        }
    }
}
636
637impl<T> GraphStage for Merge<T>
638where
639 T: Clone + Send + 'static,
640{
641 type Shape = FanInShape<T, T>;
642
643 fn name(&self) -> &str {
644 "Merge"
645 }
646
647 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
648 let inlets = (0..self.inputs)
649 .map(|index| allocator.inlet(format!("Merge.in{index}")))
650 .collect();
651 FanInShape::new(inlets, allocator.outlet_arc(merge_outlet_name()))
652 }
653
654 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
655 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
656 }
657
658 fn stage_spec_with_ports(
659 &self,
660 _shape: &Self::Shape,
661 inlets: Vec<AnyInlet>,
662 outlets: Vec<AnyOutlet>,
663 ) -> StageSpec {
664 StageSpec::merge(merge_stage_name(), inlets, outlets)
665 }
666}
667
/// Fan-in stage description with a fixed number of inlets and one outlet.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct stores only the inlet count, never a `T`.
pub struct Concat<T: 'static> {
    /// Number of inlets; always greater than one (see [`Concat::new`]).
    inputs: usize,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for Concat<T> {
    fn clone(&self) -> Self {
        Self {
            inputs: self.inputs,
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for Concat<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("Concat")
            .field("inputs", &self.inputs)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> Concat<T> {
    /// Creates a concat description with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs > 1, "concat input count must be greater than one");
        Self {
            inputs,
            _marker: PhantomData,
        }
    }
}
684
685impl<T> GraphStage for Concat<T>
686where
687 T: Clone + Send + 'static,
688{
689 type Shape = FanInShape<T, T>;
690
691 fn name(&self) -> &str {
692 "Concat"
693 }
694
695 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
696 let inlets = (0..self.inputs)
697 .map(|index| allocator.inlet(format!("Concat.in{index}")))
698 .collect();
699 FanInShape::new(inlets, allocator.outlet_arc(concat_outlet_name()))
700 }
701
702 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
703 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
704 }
705
706 fn stage_spec_with_ports(
707 &self,
708 _shape: &Self::Shape,
709 inlets: Vec<AnyInlet>,
710 outlets: Vec<AnyOutlet>,
711 ) -> StageSpec {
712 StageSpec::concat(concat_stage_name(), inlets, outlets)
713 }
714}
715
/// Stateless fan-in stage description with a primary and a secondary inlet.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct never stores a `T`.
pub struct OrElse<T: 'static> {
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for OrElse<T> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for OrElse<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("OrElse")
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> OrElse<T> {
    /// Creates a new `OrElse` stage description.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for OrElse<T> {
    /// Equivalent to [`OrElse::new`].
    fn default() -> Self {
        Self::new()
    }
}
735
736impl<T> GraphStage for OrElse<T>
737where
738 T: Clone + Send + 'static,
739{
740 type Shape = FanInShape<T, T>;
741
742 fn name(&self) -> &str {
743 "OrElse"
744 }
745
746 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
747 let inlets = vec![
748 allocator.inlet_arc(or_else_primary_name()),
749 allocator.inlet_arc(or_else_secondary_name()),
750 ];
751 FanInShape::new(inlets, allocator.outlet_arc(or_else_outlet_name()))
752 }
753
754 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
755 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
756 }
757
758 fn stage_spec_with_ports(
759 &self,
760 _shape: &Self::Shape,
761 inlets: Vec<AnyInlet>,
762 outlets: Vec<AnyOutlet>,
763 ) -> StageSpec {
764 StageSpec::or_else(or_else_stage_name(), inlets, outlets)
765 }
766}
767
/// Fan-in stage description configured by inlet count, segment size and an
/// eager-close flag.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct stores only configuration, never a `T`.
pub struct Interleave<T: 'static> {
    /// Number of inlets; always greater than one.
    inputs: usize,
    /// Segment size forwarded to the stage spec; always greater than zero.
    segment_size: usize,
    /// Eager-close flag forwarded to the stage spec.
    eager_close: bool,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for Interleave<T> {
    fn clone(&self) -> Self {
        Self {
            inputs: self.inputs,
            segment_size: self.segment_size,
            eager_close: self.eager_close,
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for Interleave<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("Interleave")
            .field("inputs", &self.inputs)
            .field("segment_size", &self.segment_size)
            .field("eager_close", &self.eager_close)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> Interleave<T> {
    /// Creates an interleave description with `eager_close` set to `false`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Interleave::new_with_eager_close`].
    #[must_use]
    pub fn new(inputs: usize, segment_size: usize) -> Self {
        Self::new_with_eager_close(inputs, segment_size, false)
    }

    /// Creates an interleave description with an explicit `eager_close` flag.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two or `segment_size` is zero.
    #[must_use]
    pub fn new_with_eager_close(inputs: usize, segment_size: usize, eager_close: bool) -> Self {
        assert!(
            inputs > 1,
            "interleave input count must be greater than one"
        );
        assert!(
            segment_size > 0,
            "interleave segment size must be greater than zero"
        );
        Self {
            inputs,
            segment_size,
            eager_close,
            _marker: PhantomData,
        }
    }
}
800
801impl<T> GraphStage for Interleave<T>
802where
803 T: Clone + Send + 'static,
804{
805 type Shape = FanInShape<T, T>;
806
807 fn name(&self) -> &str {
808 "Interleave"
809 }
810
811 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
812 let inlets = (0..self.inputs)
813 .map(|index| allocator.inlet(format!("Interleave.in{index}")))
814 .collect();
815 FanInShape::new(inlets, allocator.outlet_arc(interleave_outlet_name()))
816 }
817
818 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
819 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
820 }
821
822 fn stage_spec_with_ports(
823 &self,
824 _shape: &Self::Shape,
825 inlets: Vec<AnyInlet>,
826 outlets: Vec<AnyOutlet>,
827 ) -> StageSpec {
828 StageSpec::interleave(
829 interleave_stage_name(),
830 inlets,
831 outlets,
832 self.segment_size,
833 self.eager_close,
834 )
835 }
836}
837
/// Fan-in stage description with one preferred inlet plus a fixed number of
/// secondary inlets.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct stores only a port count, never a `T`.
pub struct MergePreferred<T: 'static> {
    /// Number of secondary inlets; always greater than zero.
    secondary_ports: usize,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for MergePreferred<T> {
    fn clone(&self) -> Self {
        Self {
            secondary_ports: self.secondary_ports,
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for MergePreferred<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("MergePreferred")
            .field("secondary_ports", &self.secondary_ports)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> MergePreferred<T> {
    /// Creates a merge-preferred description with `secondary_ports` secondary inlets.
    ///
    /// # Panics
    ///
    /// Panics if `secondary_ports` is zero.
    #[must_use]
    pub fn new(secondary_ports: usize) -> Self {
        assert!(
            secondary_ports > 0,
            "merge-preferred secondary input count must be greater than zero"
        );
        Self {
            secondary_ports,
            _marker: PhantomData,
        }
    }
}
857
858impl<T> GraphStage for MergePreferred<T>
859where
860 T: Clone + Send + 'static,
861{
862 type Shape = MergePreferredShape<T>;
863
864 fn name(&self) -> &str {
865 "MergePreferred"
866 }
867
868 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
869 let preferred = allocator.inlet_arc(merge_preferred_preferred_name());
870 let secondary = (0..self.secondary_ports)
871 .map(|index| allocator.inlet(format!("MergePreferred.in{index}")))
872 .collect();
873 MergePreferredShape::new(
874 preferred,
875 secondary,
876 allocator.outlet_arc(merge_preferred_outlet_name()),
877 )
878 }
879
880 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
881 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
882 }
883
884 fn stage_spec_with_ports(
885 &self,
886 _shape: &Self::Shape,
887 inlets: Vec<AnyInlet>,
888 outlets: Vec<AnyOutlet>,
889 ) -> StageSpec {
890 StageSpec::merge_preferred(merge_preferred_stage_name(), inlets, outlets)
891 }
892}
893
/// Fan-in stage description with one weight per inlet.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct stores only the weights, never a `T`.
pub struct MergePrioritized<T: 'static> {
    /// One strictly positive weight per inlet; never empty.
    weights: Vec<usize>,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for MergePrioritized<T> {
    fn clone(&self) -> Self {
        Self {
            weights: self.weights.clone(),
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for MergePrioritized<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("MergePrioritized")
            .field("weights", &self.weights)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> MergePrioritized<T> {
    /// Creates a prioritized-merge description with one inlet per entry of `weights`.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty or contains a zero.
    #[must_use]
    pub fn new(weights: Vec<usize>) -> Self {
        assert!(!weights.is_empty(), "prioritized merge must have inputs");
        assert!(
            weights.iter().all(|weight| *weight > 0),
            "prioritized merge weights must be greater than zero"
        );
        Self {
            weights,
            _marker: PhantomData,
        }
    }
}
914
915impl<T> GraphStage for MergePrioritized<T>
916where
917 T: Clone + Send + 'static,
918{
919 type Shape = FanInShape<T, T>;
920
921 fn name(&self) -> &str {
922 "MergePrioritized"
923 }
924
925 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
926 let inlets = (0..self.weights.len())
927 .map(|index| allocator.inlet(format!("MergePrioritized.in{index}")))
928 .collect();
929 FanInShape::new(
930 inlets,
931 allocator.outlet_arc(merge_prioritized_outlet_name()),
932 )
933 }
934
935 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
936 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
937 }
938
939 fn stage_spec_with_ports(
940 &self,
941 _shape: &Self::Shape,
942 inlets: Vec<AnyInlet>,
943 outlets: Vec<AnyOutlet>,
944 ) -> StageSpec {
945 StageSpec::merge_prioritized(
946 merge_prioritized_stage_name(),
947 inlets,
948 outlets,
949 self.weights.clone(),
950 )
951 }
952}
953
/// Stateless stage description that pairs a `Left` element with a `Right` element.
///
/// `Clone` and `Debug` are implemented by hand so they do not require bounds on
/// `Left` or `Right`; the struct never stores a value of either type.
pub struct Zip<Left: 'static, Right: 'static> {
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> (Left, Right)>,
}

impl<Left: 'static, Right: 'static> Clone for Zip<Left, Right> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<Left: 'static, Right: 'static> fmt::Debug for Zip<Left, Right> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("Zip")
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<Left: 'static, Right: 'static> Zip<Left, Right> {
    /// Creates a new `Zip` stage description.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<Left: 'static, Right: 'static> Default for Zip<Left, Right> {
    /// Equivalent to [`Zip::new`].
    fn default() -> Self {
        Self::new()
    }
}
973
974impl<Left, Right> GraphStage for Zip<Left, Right>
975where
976 Left: Clone + Send + 'static,
977 Right: Clone + Send + 'static,
978{
979 type Shape = ZipShape<Left, Right>;
980
981 fn name(&self) -> &str {
982 "Zip"
983 }
984
985 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
986 let first_id = next_port_id_block(3);
987 ZipShape::new(
988 Inlet::with_arc_name(first_id, zip_in0_name()),
989 Inlet::with_arc_name(first_id.offset(1), zip_in1_name()),
990 Outlet::with_arc_name(first_id.offset(2), zip_outlet_name()),
991 )
992 }
993
994 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
995 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
996 }
997
998 fn stage_spec_with_ports(
999 &self,
1000 _shape: &Self::Shape,
1001 inlets: Vec<AnyInlet>,
1002 outlets: Vec<AnyOutlet>,
1003 ) -> StageSpec {
1004 let zip = Arc::new(move |left: DatumValue, right: DatumValue| {
1005 let left: Left = downcast_datum(left, "zip", || "Zip.in0")?;
1006 let right: Right = downcast_datum(right, "zip", || "Zip.in1")?;
1007 Ok(datum((left, right)))
1008 });
1009 StageSpec::zip(zip_stage_name(), inlets, outlets, zip)
1010 }
1011}
1012
/// Stateless two-inlet fan-in stage description for ordered elements.
///
/// `Clone` and `Debug` are implemented by hand so they do not require `T: Clone` /
/// `T: Debug`; the struct never stores a `T`.
pub struct MergeSorted<T: 'static> {
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for MergeSorted<T> {
    fn clone(&self) -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for MergeSorted<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Mirrors the field layout that `#[derive(Debug)]` produced.
        f.debug_struct("MergeSorted")
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T: 'static> MergeSorted<T> {
    /// Creates a new `MergeSorted` stage description.
    #[must_use]
    pub fn new() -> Self {
        Self {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for MergeSorted<T> {
    /// Equivalent to [`MergeSorted::new`].
    fn default() -> Self {
        Self::new()
    }
}
1032
1033impl<T> GraphStage for MergeSorted<T>
1034where
1035 T: Clone + Ord + Send + 'static,
1036{
1037 type Shape = FanInShape<T, T>;
1038
1039 fn name(&self) -> &str {
1040 "MergeSorted"
1041 }
1042
1043 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1044 let inlets = vec![
1045 allocator.inlet("MergeSorted.in0"),
1046 allocator.inlet("MergeSorted.in1"),
1047 ];
1048 FanInShape::new(inlets, allocator.outlet("MergeSorted.out"))
1049 }
1050
1051 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1052 let compare = Arc::new(
1053 move |a: &DatumValue, b: &DatumValue| -> std::cmp::Ordering {
1054 let a_t: &T = a
1055 .as_any_ref()
1056 .downcast_ref::<T>()
1057 .expect("merge-sorted compare: wrong element type");
1058 let b_t: &T = b
1059 .as_any_ref()
1060 .downcast_ref::<T>()
1061 .expect("merge-sorted compare: wrong element type");
1062 a_t.cmp(b_t)
1063 },
1064 );
1065 StageSpec::merge_sorted(
1066 Arc::from(self.name()),
1067 shape.inlets(),
1068 shape.outlets(),
1069 compare,
1070 )
1071 }
1072
1073 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1074 struct State<T> {
1075 left: VecDeque<T>,
1076 right: VecDeque<T>,
1077 left_closed: bool,
1078 right_closed: bool,
1079 pending: VecDeque<T>,
1080 }
1081
1082 impl<T> Default for State<T> {
1083 fn default() -> Self {
1084 Self {
1085 left: VecDeque::new(),
1086 right: VecDeque::new(),
1087 left_closed: false,
1088 right_closed: false,
1089 pending: VecDeque::new(),
1090 }
1091 }
1092 }
1093
1094 fn maybe_queue_output<T>(state: &mut State<T>) -> bool
1095 where
1096 T: Clone + Ord,
1097 {
1098 let next = match (state.left.front(), state.right.front()) {
1099 (Some(left), Some(right)) => {
1100 if left <= right {
1101 state.left.pop_front()
1102 } else {
1103 state.right.pop_front()
1104 }
1105 }
1106 (Some(_), None) if state.right_closed => state.left.pop_front(),
1107 (None, Some(_)) if state.left_closed => state.right.pop_front(),
1108 _ => None,
1109 };
1110 if let Some(value) = next {
1111 state.pending.push_back(value);
1112 true
1113 } else {
1114 false
1115 }
1116 }
1117
1118 fn maybe_complete<T>(
1119 logic: &mut GraphStageLogic,
1120 outlet: &Outlet<T>,
1121 state: &State<T>,
1122 ) -> StreamResult<()>
1123 where
1124 T: Clone + Send + 'static,
1125 {
1126 if state.left_closed
1127 && state.right_closed
1128 && state.left.is_empty()
1129 && state.right.is_empty()
1130 && state.pending.is_empty()
1131 && !logic.is_closed(outlet)
1132 {
1133 logic.complete(outlet)?;
1134 }
1135 Ok(())
1136 }
1137
1138 fn maybe_pull<T>(
1139 logic: &mut GraphStageLogic,
1140 left: &Inlet<T>,
1141 right: &Inlet<T>,
1142 state: &State<T>,
1143 ) -> StreamResult<()>
1144 where
1145 T: Clone + Send + 'static,
1146 {
1147 if state.left.is_empty() && !state.left_closed && !logic.has_been_pulled(left) {
1148 logic.pull(left)?;
1149 }
1150 if state.right.is_empty() && !state.right_closed && !logic.has_been_pulled(right) {
1151 logic.pull(right)?;
1152 }
1153 Ok(())
1154 }
1155
1156 fn maybe_drain<T>(
1157 logic: &mut GraphStageLogic,
1158 outlet: &Outlet<T>,
1159 state: &Arc<Mutex<State<T>>>,
1160 ) -> StreamResult<()>
1161 where
1162 T: Clone + Ord + Send + 'static,
1163 {
1164 let next = if logic.is_available(outlet) {
1165 state
1166 .lock()
1167 .expect("merge-sorted state poisoned")
1168 .pending
1169 .pop_front()
1170 } else {
1171 None
1172 };
1173 if let Some(value) = next {
1174 logic.push(outlet, value)?;
1175 }
1176 Ok(())
1177 }
1178
1179 struct In<T: 'static> {
1180 inlet_id: PortId,
1181 left: Inlet<T>,
1182 right: Inlet<T>,
1183 outlet: Outlet<T>,
1184 state: Arc<Mutex<State<T>>>,
1185 }
1186
1187 impl<T> InHandler for In<T>
1188 where
1189 T: Clone + Ord + Send + 'static,
1190 {
1191 fn on_push(
1192 &mut self,
1193 logic: &mut GraphStageLogic,
1194 _inlet: AnyInlet,
1195 ) -> StreamResult<()> {
1196 let value: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1197 downcast_datum(value, "grab", || {
1198 format!("inlet#{}", self.inlet_id.as_usize())
1199 })
1200 })?;
1201 {
1202 let mut state = self.state.lock().expect("merge-sorted state poisoned");
1203 if self.inlet_id == self.left.id() {
1204 state.left.push_back(value);
1205 } else {
1206 state.right.push_back(value);
1207 }
1208 while maybe_queue_output(&mut state) {}
1209 }
1210 maybe_drain(logic, &self.outlet, &self.state)?;
1211 let state = self.state.lock().expect("merge-sorted state poisoned");
1212 maybe_complete(logic, &self.outlet, &state)?;
1213 maybe_pull(logic, &self.left, &self.right, &state)
1214 }
1215
1216 fn on_upstream_finish(
1217 &mut self,
1218 logic: &mut GraphStageLogic,
1219 _inlet: AnyInlet,
1220 ) -> StreamResult<()> {
1221 {
1222 let mut state = self.state.lock().expect("merge-sorted state poisoned");
1223 if self.inlet_id == self.left.id() {
1224 state.left_closed = true;
1225 } else {
1226 state.right_closed = true;
1227 }
1228 while maybe_queue_output(&mut state) {}
1229 }
1230 maybe_drain(logic, &self.outlet, &self.state)?;
1231 let state = self.state.lock().expect("merge-sorted state poisoned");
1232 maybe_complete(logic, &self.outlet, &state)?;
1233 maybe_pull(logic, &self.left, &self.right, &state)
1234 }
1235 }
1236
1237 struct Out<T: 'static> {
1238 left: Inlet<T>,
1239 right: Inlet<T>,
1240 outlet: Outlet<T>,
1241 state: Arc<Mutex<State<T>>>,
1242 }
1243
1244 impl<T> OutHandler for Out<T>
1245 where
1246 T: Clone + Ord + Send + 'static,
1247 {
1248 fn on_pull(
1249 &mut self,
1250 logic: &mut GraphStageLogic,
1251 _outlet: AnyOutlet,
1252 ) -> StreamResult<()> {
1253 maybe_drain(logic, &self.outlet, &self.state)?;
1254 let state = self.state.lock().expect("merge-sorted state poisoned");
1255 maybe_complete(logic, &self.outlet, &state)?;
1256 maybe_pull(logic, &self.left, &self.right, &state)
1257 }
1258 }
1259
1260 let state = Arc::new(Mutex::new(State::<T>::default()));
1261 let left = shape.inlet(0).expect("merge-sorted left inlet");
1262 let right = shape.inlet(1).expect("merge-sorted right inlet");
1263 let outlet = shape.outlet();
1264 let mut logic = GraphStageLogic::new(shape);
1265 logic
1266 .set_handler(
1267 &left,
1268 Box::new(In {
1269 inlet_id: left.id(),
1270 left: left.clone(),
1271 right: right.clone(),
1272 outlet: outlet.clone(),
1273 state: Arc::clone(&state),
1274 }),
1275 )
1276 .unwrap();
1277 logic
1278 .set_handler(
1279 &right,
1280 Box::new(In {
1281 inlet_id: right.id(),
1282 left: left.clone(),
1283 right: right.clone(),
1284 outlet: outlet.clone(),
1285 state: Arc::clone(&state),
1286 }),
1287 )
1288 .unwrap();
1289 logic
1290 .set_out_handler(
1291 &outlet.clone(),
1292 Box::new(Out {
1293 left,
1294 right,
1295 outlet: outlet.clone(),
1296 state,
1297 }),
1298 )
1299 .unwrap();
1300 logic
1301 }
1302}
1303
/// Fan-in stage description that orders elements by a `u64` sequence number
/// extracted from each element.
///
/// The extractor lives behind an `Arc`, so cloning the stage only bumps a reference
/// count. `Clone` is written by hand because a derive would require `T: Clone`,
/// although no `T` is stored.
pub struct MergeSequence<T: 'static> {
    /// Number of inlets; always greater than one (see [`MergeSequence::new`]).
    inputs: usize,
    /// Sequence-number extractor, shared between clones.
    extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
    /// Type-level marker; holds no data.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Clone for MergeSequence<T> {
    fn clone(&self) -> Self {
        Self {
            inputs: self.inputs,
            extract_sequence: Arc::clone(&self.extract_sequence),
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> fmt::Debug for MergeSequence<T> {
    /// Prints the inlet count only; the closure itself cannot be formatted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MergeSequence")
            .field("inputs", &self.inputs)
            .finish_non_exhaustive()
    }
}

impl<T: 'static> MergeSequence<T> {
    /// Creates a merge-sequence description with `inputs` inlets and the given
    /// sequence extractor, which is wrapped in an `Arc` for sharing.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new<F>(inputs: usize, extract_sequence: F) -> Self
    where
        F: Fn(&T) -> u64 + Send + Sync + 'static,
    {
        assert!(
            inputs > 1,
            "merge sequence input count must be greater than one"
        );
        Self {
            inputs,
            extract_sequence: Arc::new(extract_sequence),
            _marker: PhantomData,
        }
    }
}
1336
1337impl<T> GraphStage for MergeSequence<T>
1338where
1339 T: Clone + Send + 'static,
1340{
1341 type Shape = FanInShape<T, T>;
1342
1343 fn name(&self) -> &str {
1344 "MergeSequence"
1345 }
1346
1347 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1348 let inlets = (0..self.inputs)
1349 .map(|index| allocator.inlet(format!("MergeSequence.in{index}")))
1350 .collect();
1351 FanInShape::new(inlets, allocator.outlet("MergeSequence.out"))
1352 }
1353
1354 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1355 let extract = Arc::clone(&self.extract_sequence);
1356 let extract_sequence = Arc::new(move |dv: &DatumValue| -> u64 {
1357 let t: &T = dv
1358 .as_any_ref()
1359 .downcast_ref::<T>()
1360 .expect("merge-sequence extract: wrong element type");
1361 extract(t)
1362 });
1363 let typed_extract_fn = Arc::clone(&self.extract_sequence);
1366 let typed_extract: Arc<dyn Fn(&T) -> u64 + Send + Sync> = typed_extract_fn;
1367 let typed_extract: Arc<StageTypedSequenceFn> = Arc::new(typed_extract);
1368 StageSpec::merge_sequence(
1369 Arc::from(self.name()),
1370 shape.inlets(),
1371 shape.outlets(),
1372 self.inputs,
1373 extract_sequence,
1374 typed_extract,
1375 )
1376 }
1377
1378 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1379 #[derive(Clone)]
1380 struct Pending<T> {
1381 sequence: u64,
1382 elem: T,
1383 }
1384
1385 struct State<T> {
1386 next_sequence: u64,
1387 pending: Vec<Pending<T>>,
1388 completed: usize,
1389 pending_output: VecDeque<T>,
1390 }
1391
1392 fn try_emit_pending<T>(state: &mut State<T>) -> StreamResult<()>
1393 where
1394 T: Clone + Send + 'static,
1395 {
1396 while let Some(index) = state
1397 .pending
1398 .iter()
1399 .position(|item| item.sequence == state.next_sequence)
1400 {
1401 let item = state.pending.remove(index);
1402 if state
1403 .pending
1404 .iter()
1405 .any(|other| other.sequence == state.next_sequence)
1406 {
1407 return Err(StreamError::Failed(format!(
1408 "duplicate sequence {} on merge sequence",
1409 state.next_sequence
1410 )));
1411 }
1412 state.pending_output.push_back(item.elem);
1413 state.next_sequence += 1;
1414 }
1415 Ok(())
1416 }
1417
1418 struct In<T: 'static> {
1419 inlet_id: PortId,
1420 inlet_index: usize,
1421 inlet: Inlet<T>,
1422 all_inlets: Vec<Inlet<T>>,
1423 outlet: Outlet<T>,
1424 extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
1425 state: Arc<Mutex<State<T>>>,
1426 }
1427
1428 impl<T> InHandler for In<T>
1429 where
1430 T: Clone + Send + 'static,
1431 {
1432 fn on_push(
1433 &mut self,
1434 logic: &mut GraphStageLogic,
1435 _inlet: AnyInlet,
1436 ) -> StreamResult<()> {
1437 let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1438 downcast_datum(value, "grab", || {
1439 format!("inlet#{}", self.inlet_id.as_usize())
1440 })
1441 })?;
1442 {
1443 let mut state = self.state.lock().expect("merge-sequence state poisoned");
1444 let sequence = (self.extract_sequence)(&elem);
1445 if sequence < state.next_sequence {
1446 return Err(StreamError::Failed(format!(
1447 "sequence regression from {} to {} on port {}",
1448 state.next_sequence, sequence, self.inlet_index
1449 )));
1450 }
1451 state.pending.push(Pending { sequence, elem });
1452 try_emit_pending(&mut state)?;
1453 }
1454 let next = if logic.is_available(&self.outlet) {
1455 self.state
1456 .lock()
1457 .expect("merge-sequence state poisoned")
1458 .pending_output
1459 .pop_front()
1460 } else {
1461 None
1462 };
1463 if let Some(value) = next {
1464 logic.push(&self.outlet, value)?;
1465 }
1466 let state = self.state.lock().expect("merge-sequence state poisoned");
1467 if state.completed == self.all_inlets.len()
1468 && state.pending.is_empty()
1469 && state.pending_output.is_empty()
1470 {
1471 logic.complete(&self.outlet)?;
1472 } else if logic.is_available(&self.outlet)
1473 && state.pending_output.is_empty()
1474 && state.pending.len() + state.completed == self.all_inlets.len()
1475 {
1476 return Err(StreamError::Failed(format!(
1477 "expected sequence {}, but all input ports have pushed or are complete",
1478 state.next_sequence
1479 )));
1480 }
1481 if !logic.has_been_pulled(&self.inlet) {
1482 logic.pull(&self.inlet)?;
1483 }
1484 Ok(())
1485 }
1486
1487 fn on_upstream_finish(
1488 &mut self,
1489 logic: &mut GraphStageLogic,
1490 _inlet: AnyInlet,
1491 ) -> StreamResult<()> {
1492 {
1493 let mut state = self.state.lock().expect("merge-sequence state poisoned");
1494 state.completed += 1;
1495 }
1496 let state = self.state.lock().expect("merge-sequence state poisoned");
1497 if state.completed == self.all_inlets.len()
1498 && state.pending.is_empty()
1499 && state.pending_output.is_empty()
1500 {
1501 logic.complete(&self.outlet)?;
1502 } else if logic.is_available(&self.outlet)
1503 && state.pending_output.is_empty()
1504 && state.pending.len() + state.completed == self.all_inlets.len()
1505 {
1506 return Err(StreamError::Failed(format!(
1507 "expected sequence {}, but all input ports have pushed or are complete",
1508 state.next_sequence
1509 )));
1510 }
1511 Ok(())
1512 }
1513 }
1514
1515 struct Out<T: 'static> {
1516 inlets: Vec<Inlet<T>>,
1517 outlet: Outlet<T>,
1518 state: Arc<Mutex<State<T>>>,
1519 }
1520
1521 impl<T> OutHandler for Out<T>
1522 where
1523 T: Clone + Send + 'static,
1524 {
1525 fn on_pull(
1526 &mut self,
1527 logic: &mut GraphStageLogic,
1528 _outlet: AnyOutlet,
1529 ) -> StreamResult<()> {
1530 let next = self
1531 .state
1532 .lock()
1533 .expect("merge-sequence state poisoned")
1534 .pending_output
1535 .pop_front();
1536 if let Some(value) = next {
1537 logic.push(&self.outlet, value)?;
1538 } else {
1539 let state = self.state.lock().expect("merge-sequence state poisoned");
1540 if state.completed == self.inlets.len() && state.pending.is_empty() {
1541 logic.complete(&self.outlet)?;
1542 } else if state.pending.len() + state.completed == self.inlets.len() {
1543 return Err(StreamError::Failed(format!(
1544 "expected sequence {}, but all input ports have pushed or are complete",
1545 state.next_sequence
1546 )));
1547 }
1548 }
1549 for inlet in &self.inlets {
1550 if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1551 logic.pull(inlet)?;
1552 }
1553 }
1554 Ok(())
1555 }
1556 }
1557
1558 let inlets = shape.inlets_vec();
1559 let outlet = shape.outlet();
1560 let state = Arc::new(Mutex::new(State {
1561 next_sequence: 0,
1562 pending: Vec::new(),
1563 completed: 0,
1564 pending_output: VecDeque::new(),
1565 }));
1566 let mut logic = GraphStageLogic::new(shape);
1567 for (index, inlet) in inlets.iter().cloned().enumerate() {
1568 logic
1569 .set_handler(
1570 &inlet.clone(),
1571 Box::new(In {
1572 inlet_id: inlet.id(),
1573 inlet_index: index,
1574 inlet: inlet.clone(),
1575 all_inlets: inlets.clone(),
1576 outlet: outlet.clone(),
1577 extract_sequence: Arc::clone(&self.extract_sequence),
1578 state: Arc::clone(&state),
1579 }),
1580 )
1581 .unwrap();
1582 }
1583 logic
1584 .set_out_handler(
1585 &outlet.clone(),
1586 Box::new(Out {
1587 inlets,
1588 outlet: outlet.clone(),
1589 state,
1590 }),
1591 )
1592 .unwrap();
1593 logic
1594 }
1595}
1596
/// Fan-in stage description whose output elements are `Vec<T>` snapshots holding the most
/// recent element of every input.
#[derive(Clone, Debug)]
pub struct MergeLatest<T>
where
    T: 'static,
{
    /// Number of inlets allocated for the stage.
    inputs: usize,
    /// Whether the stage may complete before all of its inputs have finished.
    eager_complete: bool,
    /// Ties the stage to `T` without storing a value of that type.
    _marker: PhantomData<fn() -> T>,
}
1603
1604impl<T: 'static> MergeLatest<T> {
1605 #[must_use]
1606 pub fn new(inputs: usize, eager_complete: bool) -> Self {
1607 assert!(
1608 inputs > 0,
1609 "merge-latest input count must be greater than zero"
1610 );
1611 Self {
1612 inputs,
1613 eager_complete,
1614 _marker: PhantomData,
1615 }
1616 }
1617}
1618
1619impl<T> GraphStage for MergeLatest<T>
1620where
1621 T: Clone + Send + 'static,
1622{
1623 type Shape = FanInShape<T, Vec<T>>;
1624
1625 fn name(&self) -> &str {
1626 "MergeLatest"
1627 }
1628
1629 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1630 let inlets = (0..self.inputs)
1631 .map(|index| allocator.inlet(format!("MergeLatest.in{index}")))
1632 .collect();
1633 FanInShape::new(inlets, allocator.outlet("MergeLatest.out"))
1634 }
1635
1636 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1637 let build_snapshot = Arc::new(move |values: &[&DatumValue]| -> DatumValue {
1638 let snapshot: Vec<T> = values
1639 .iter()
1640 .map(|dv| {
1641 dv.as_any_ref()
1642 .downcast_ref::<T>()
1643 .cloned()
1644 .expect("merge-latest snapshot: wrong element type")
1645 })
1646 .collect();
1647 datum(snapshot)
1648 });
1649 #[allow(clippy::type_complexity)]
1652 let typed_snapshot_fn: Arc<dyn Fn(&[Option<T>]) -> Vec<T> + Send + Sync> =
1653 Arc::new(move |slots: &[Option<T>]| {
1654 slots
1655 .iter()
1656 .map(|s| {
1657 s.clone()
1658 .expect("merge-latest typed snapshot: slot is None")
1659 })
1660 .collect()
1661 });
1662 let typed_snapshot: Arc<StageTypedSnapshotFn> = Arc::new(typed_snapshot_fn);
1663 StageSpec::merge_latest(
1664 Arc::from(self.name()),
1665 shape.inlets(),
1666 shape.outlets(),
1667 self.inputs,
1668 self.eager_complete,
1669 build_snapshot,
1670 typed_snapshot,
1671 )
1672 }
1673
1674 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1675 struct State<T> {
1676 latest: Vec<Option<T>>,
1677 seen: usize,
1678 completed: usize,
1679 pending: VecDeque<Vec<T>>,
1680 eager_complete: bool,
1681 }
1682
1683 struct In<T: 'static> {
1684 inlet_id: PortId,
1685 inlet_index: usize,
1686 inlet: Inlet<T>,
1687 all_inlets: Vec<Inlet<T>>,
1688 outlet: Outlet<Vec<T>>,
1689 state: Arc<Mutex<State<T>>>,
1690 }
1691
1692 impl<T> InHandler for In<T>
1693 where
1694 T: Clone + Send + 'static,
1695 {
1696 fn on_push(
1697 &mut self,
1698 logic: &mut GraphStageLogic,
1699 _inlet: AnyInlet,
1700 ) -> StreamResult<()> {
1701 let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1702 downcast_datum(value, "grab", || {
1703 format!("inlet#{}", self.inlet_id.as_usize())
1704 })
1705 })?;
1706 {
1707 let mut state = self.state.lock().expect("merge-latest state poisoned");
1708 if state.latest[self.inlet_index].is_none() {
1709 state.seen += 1;
1710 }
1711 state.latest[self.inlet_index] = Some(elem);
1712 if state.seen == state.latest.len() {
1713 let snapshot = state
1714 .latest
1715 .iter()
1716 .map(|item| item.clone().expect("merge-latest seen"))
1717 .collect();
1718 state.pending.push_back(snapshot);
1719 }
1720 }
1721 let next = if logic.is_available(&self.outlet) {
1722 self.state
1723 .lock()
1724 .expect("merge-latest state poisoned")
1725 .pending
1726 .pop_front()
1727 } else {
1728 None
1729 };
1730 if let Some(value) = next {
1731 logic.push(&self.outlet, value)?;
1732 }
1733 if !logic.has_been_pulled(&self.inlet) {
1734 logic.pull(&self.inlet)?;
1735 }
1736 Ok(())
1737 }
1738
1739 fn on_upstream_finish(
1740 &mut self,
1741 logic: &mut GraphStageLogic,
1742 _inlet: AnyInlet,
1743 ) -> StreamResult<()> {
1744 let state = {
1745 let mut state = self.state.lock().expect("merge-latest state poisoned");
1746 state.completed += 1;
1747 (
1748 state.completed == self.all_inlets.len(),
1749 state.eager_complete,
1750 state.pending.is_empty(),
1751 )
1752 };
1753 if state.0 || (state.1 && state.2) {
1754 logic.complete(&self.outlet)?;
1755 }
1756 Ok(())
1757 }
1758 }
1759
1760 struct Out<T: 'static> {
1761 inlets: Vec<Inlet<T>>,
1762 outlet: Outlet<Vec<T>>,
1763 state: Arc<Mutex<State<T>>>,
1764 }
1765
1766 impl<T> OutHandler for Out<T>
1767 where
1768 T: Clone + Send + 'static,
1769 {
1770 fn on_pull(
1771 &mut self,
1772 logic: &mut GraphStageLogic,
1773 _outlet: AnyOutlet,
1774 ) -> StreamResult<()> {
1775 let next = self
1776 .state
1777 .lock()
1778 .expect("merge-latest state poisoned")
1779 .pending
1780 .pop_front();
1781 if let Some(value) = next {
1782 logic.push(&self.outlet, value)?;
1783 } else {
1784 let state = self.state.lock().expect("merge-latest state poisoned");
1785 if state.completed == self.inlets.len()
1786 || (state.eager_complete && state.completed > 0)
1787 {
1788 logic.complete(&self.outlet)?;
1789 }
1790 }
1791 for inlet in &self.inlets {
1792 if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1793 logic.pull(inlet)?;
1794 }
1795 }
1796 Ok(())
1797 }
1798 }
1799
1800 let inlets = shape.inlets_vec();
1801 let outlet = shape.outlet();
1802 let state = Arc::new(Mutex::new(State {
1803 latest: vec![None; inlets.len()],
1804 seen: 0,
1805 completed: 0,
1806 pending: VecDeque::new(),
1807 eager_complete: self.eager_complete,
1808 }));
1809 let mut logic = GraphStageLogic::new(shape);
1810 for (index, inlet) in inlets.iter().cloned().enumerate() {
1811 logic
1812 .set_handler(
1813 &inlet.clone(),
1814 Box::new(In {
1815 inlet_id: inlet.id(),
1816 inlet_index: index,
1817 inlet: inlet.clone(),
1818 all_inlets: inlets.clone(),
1819 outlet: outlet.clone(),
1820 state: Arc::clone(&state),
1821 }),
1822 )
1823 .unwrap();
1824 }
1825 logic
1826 .set_out_handler(
1827 &outlet.clone(),
1828 Box::new(Out {
1829 inlets,
1830 outlet: outlet.clone(),
1831 state,
1832 }),
1833 )
1834 .unwrap();
1835 logic
1836 }
1837}
1838
/// Fan-out stage description that routes each element to the outlet whose index is chosen
/// by `partitioner`.
#[derive(Clone)]
pub struct Partition<T>
where
    T: 'static,
{
    /// Number of outlets allocated for the stage.
    outputs: usize,
    /// Maps an element to the index of the outlet that should receive it.
    partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
    /// When set, cancellation of any single outlet completes the whole stage.
    eager_cancel: bool,
    /// Ties the stage to `T` without storing a value of that type.
    _marker: PhantomData<fn() -> T>,
}
1846
1847impl<T: 'static> fmt::Debug for Partition<T> {
1848 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1849 f.debug_struct("Partition")
1850 .field("outputs", &self.outputs)
1851 .field("eager_cancel", &self.eager_cancel)
1852 .finish_non_exhaustive()
1853 }
1854}
1855
1856impl<T: 'static> Partition<T> {
1857 #[must_use]
1858 pub fn new<F>(outputs: usize, partitioner: F) -> Self
1859 where
1860 F: Fn(&T) -> usize + Send + Sync + 'static,
1861 {
1862 Self::new_with_eager_cancel(outputs, partitioner, false)
1863 }
1864
1865 #[must_use]
1866 pub fn new_with_eager_cancel<F>(outputs: usize, partitioner: F, eager_cancel: bool) -> Self
1867 where
1868 F: Fn(&T) -> usize + Send + Sync + 'static,
1869 {
1870 assert!(
1871 outputs > 0,
1872 "partition output count must be greater than zero"
1873 );
1874 Self {
1875 outputs,
1876 partitioner: Arc::new(partitioner),
1877 eager_cancel,
1878 _marker: PhantomData,
1879 }
1880 }
1881}
1882
1883impl<T> GraphStage for Partition<T>
1884where
1885 T: Clone + Send + 'static,
1886{
1887 type Shape = FanOutShape<T, T>;
1888
1889 fn name(&self) -> &str {
1890 "Partition"
1891 }
1892
1893 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1894 let inlet = allocator.inlet("Partition.in");
1895 let outlets = (0..self.outputs)
1896 .map(|index| allocator.outlet(format!("Partition.out{index}")))
1897 .collect();
1898 FanOutShape::new(inlet, outlets)
1899 }
1900
1901 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1902 let partitioner_clone = Arc::clone(&self.partitioner);
1903 let partitioner = Arc::new(move |dv: &DatumValue| -> usize {
1904 let t: &T = dv
1905 .as_any_ref()
1906 .downcast_ref::<T>()
1907 .expect("partition: wrong element type");
1908 partitioner_clone(t)
1909 });
1910 let typed_partitioner_fn: Arc<dyn Fn(&T) -> usize + Send + Sync> =
1911 Arc::clone(&self.partitioner);
1912 let typed_partitioner: Arc<StageTypedPartitionFn> = Arc::new(typed_partitioner_fn);
1913 StageSpec::partition(
1914 Arc::from(self.name()),
1915 shape.inlets(),
1916 shape.outlets(),
1917 self.outputs,
1918 partitioner,
1919 typed_partitioner,
1920 self.eager_cancel,
1921 )
1922 }
1923
1924 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1925 struct State<T> {
1926 pending: Option<(usize, T)>,
1927 upstream_closed: bool,
1928 live_outlets: usize,
1929 cancelled: Vec<bool>,
1930 eager_cancel: bool,
1931 }
1932
1933 fn any_live_demand<T>(
1934 logic: &GraphStageLogic,
1935 outlets: &[Outlet<T>],
1936 cancelled: &[bool],
1937 ) -> bool
1938 where
1939 T: Clone + Send + 'static,
1940 {
1941 outlets
1942 .iter()
1943 .enumerate()
1944 .any(|(index, outlet)| !cancelled[index] && logic.is_available(outlet))
1945 }
1946
1947 struct In<T: 'static> {
1948 inlet_id: PortId,
1949 inlet: Inlet<T>,
1950 outlets: Vec<Outlet<T>>,
1951 partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
1952 state: Arc<Mutex<State<T>>>,
1953 }
1954
1955 impl<T> InHandler for In<T>
1956 where
1957 T: Clone + Send + 'static,
1958 {
1959 fn on_push(
1960 &mut self,
1961 logic: &mut GraphStageLogic,
1962 _inlet: AnyInlet,
1963 ) -> StreamResult<()> {
1964 let item: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1965 downcast_datum(value, "grab", || {
1966 format!("inlet#{}", self.inlet_id.as_usize())
1967 })
1968 })?;
1969 let idx = (self.partitioner)(&item);
1970 if idx >= self.outlets.len() {
1971 return Err(StreamError::Failed(format!(
1972 "partitioner returned out-of-bounds index {idx} for {} outputs",
1973 self.outlets.len()
1974 )));
1975 }
1976 let mut pull_again = false;
1977 {
1978 let mut state = self.state.lock().expect("partition state poisoned");
1979 if state.cancelled[idx] {
1980 pull_again = !state.upstream_closed
1981 && any_live_demand(logic, &self.outlets, &state.cancelled);
1982 } else if logic.is_available(&self.outlets[idx]) {
1983 logic.push(&self.outlets[idx], item)?;
1984 pull_again = !state.upstream_closed
1985 && any_live_demand(logic, &self.outlets, &state.cancelled);
1986 } else {
1987 state.pending = Some((idx, item));
1988 }
1989 }
1990 if pull_again && !logic.has_been_pulled(&self.inlet) {
1991 logic.pull(&self.inlet)?;
1992 }
1993 Ok(())
1994 }
1995
1996 fn on_upstream_finish(
1997 &mut self,
1998 logic: &mut GraphStageLogic,
1999 _inlet: AnyInlet,
2000 ) -> StreamResult<()> {
2001 let complete_now = {
2002 let mut state = self.state.lock().expect("partition state poisoned");
2003 state.upstream_closed = true;
2004 state.pending.is_none()
2005 };
2006 if complete_now {
2007 for outlet in &self.outlets {
2008 if !logic.is_closed(outlet) {
2009 logic.complete(outlet)?;
2010 }
2011 }
2012 }
2013 Ok(())
2014 }
2015 }
2016
2017 struct Out<T: 'static> {
2018 index: usize,
2019 inlet: Inlet<T>,
2020 outlets: Vec<Outlet<T>>,
2021 state: Arc<Mutex<State<T>>>,
2022 }
2023
2024 impl<T> OutHandler for Out<T>
2025 where
2026 T: Clone + Send + 'static,
2027 {
2028 fn on_pull(
2029 &mut self,
2030 logic: &mut GraphStageLogic,
2031 _outlet: AnyOutlet,
2032 ) -> StreamResult<()> {
2033 let mut complete_now = false;
2034 let pending = {
2035 let mut state = self.state.lock().expect("partition state poisoned");
2036 if let Some((idx, _)) = &state.pending
2037 && *idx == self.index
2038 {
2039 state.pending.take()
2040 } else {
2041 None
2042 }
2043 };
2044 if let Some((_, item)) = pending {
2045 logic.push(&self.outlets[self.index], item)?;
2046 let state = self.state.lock().expect("partition state poisoned");
2047 if state.upstream_closed {
2048 complete_now = true;
2049 } else if any_live_demand(logic, &self.outlets, &state.cancelled)
2050 && !logic.has_been_pulled(&self.inlet)
2051 {
2052 logic.pull(&self.inlet)?;
2053 }
2054 } else {
2055 let state = self.state.lock().expect("partition state poisoned");
2056 if state.upstream_closed {
2057 complete_now = true;
2058 } else if any_live_demand(logic, &self.outlets, &state.cancelled)
2059 && !logic.has_been_pulled(&self.inlet)
2060 {
2061 logic.pull(&self.inlet)?;
2062 }
2063 }
2064 if complete_now {
2065 for outlet in &self.outlets {
2066 if !logic.is_closed(outlet) {
2067 logic.complete(outlet)?;
2068 }
2069 }
2070 }
2071 Ok(())
2072 }
2073
2074 fn on_downstream_finish(
2075 &mut self,
2076 logic: &mut GraphStageLogic,
2077 _outlet: AnyOutlet,
2078 ) -> StreamResult<()> {
2079 let (cancel_stage, clear_pending) = {
2080 let mut state = self.state.lock().expect("partition state poisoned");
2081 if state.cancelled[self.index] {
2082 return Ok(());
2083 }
2084 state.cancelled[self.index] = true;
2085 state.live_outlets -= 1;
2086 let clear_pending = state
2087 .pending
2088 .as_ref()
2089 .is_some_and(|(idx, _)| *idx == self.index);
2090 let cancel_stage = state.eager_cancel || state.live_outlets == 0;
2091 if clear_pending {
2092 state.pending = None;
2093 }
2094 (cancel_stage, clear_pending)
2095 };
2096 if cancel_stage {
2097 logic.complete_stage()?;
2098 } else if clear_pending
2099 && !logic.has_been_pulled(&self.inlet)
2100 && !logic.is_closed(&self.inlet)
2101 {
2102 let state = self.state.lock().expect("partition state poisoned");
2103 if any_live_demand(logic, &self.outlets, &state.cancelled) {
2104 logic.pull(&self.inlet)?;
2105 }
2106 }
2107 Ok(())
2108 }
2109 }
2110
2111 let inlet = shape.inlet();
2112 let outlets = shape.outlets_vec();
2113 let state = Arc::new(Mutex::new(State {
2114 pending: None,
2115 upstream_closed: false,
2116 live_outlets: outlets.len(),
2117 cancelled: vec![false; outlets.len()],
2118 eager_cancel: self.eager_cancel,
2119 }));
2120 let mut logic = GraphStageLogic::new(shape);
2121 logic
2122 .set_handler(
2123 &inlet,
2124 Box::new(In {
2125 inlet_id: inlet.id(),
2126 inlet: inlet.clone(),
2127 outlets: outlets.clone(),
2128 partitioner: Arc::clone(&self.partitioner),
2129 state: Arc::clone(&state),
2130 }),
2131 )
2132 .unwrap();
2133 for (index, outlet) in outlets.iter().cloned().enumerate() {
2134 logic
2135 .set_out_handler(
2136 &outlet,
2137 Box::new(Out {
2138 index,
2139 inlet: inlet.clone(),
2140 outlets: outlets.clone(),
2141 state: Arc::clone(&state),
2142 }),
2143 )
2144 .unwrap();
2145 }
2146 logic
2147 }
2148}
2149
/// Fan-out stage description that splits `(A, B)` pairs into an `A` output and a `B` output.
#[derive(Clone, Debug)]
pub struct Unzip<A, B>
where
    A: 'static,
    B: 'static,
{
    /// Ties the stage to `A` and `B` without storing values of those types.
    _marker: PhantomData<fn() -> (A, B)>,
}
2154
2155impl<A: 'static, B: 'static> Unzip<A, B> {
2156 #[must_use]
2157 pub fn new() -> Self {
2158 Self {
2159 _marker: PhantomData,
2160 }
2161 }
2162}
2163
2164impl<A: 'static, B: 'static> Default for Unzip<A, B> {
2165 fn default() -> Self {
2166 Self::new()
2167 }
2168}
2169
2170impl<A, B> GraphStage for Unzip<A, B>
2171where
2172 A: Clone + Send + 'static,
2173 B: Clone + Send + 'static,
2174{
2175 type Shape = FanOutShape2<(A, B), A, B>;
2176
2177 fn name(&self) -> &str {
2178 "Unzip"
2179 }
2180
2181 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
2182 FanOutShape2::new(
2183 allocator.inlet("Unzip.in"),
2184 allocator.outlet("Unzip.out0"),
2185 allocator.outlet("Unzip.out1"),
2186 )
2187 }
2188
2189 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2190 let split = Arc::new(|dv: DatumValue| -> (DatumValue, DatumValue) {
2191 let pair: (A, B) =
2192 downcast_datum(dv, "unzip", || "Unzip.in").expect("unzip: wrong element type");
2193 (datum(pair.0), datum(pair.1))
2194 });
2195 #[allow(clippy::type_complexity)]
2200 let typed_split_fn: Arc<dyn Fn((A, B)) -> (A, B) + Send + Sync> =
2201 Arc::new(|pair: (A, B)| pair);
2202 let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
2203 StageSpec::unzip(
2204 Arc::from(self.name()),
2205 shape.inlets(),
2206 shape.outlets(),
2207 split,
2208 typed_split,
2209 )
2210 }
2211
2212 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
2213 UnzipWith::new(|pair: (A, B)| pair).create_logic(shape)
2214 }
2215}
2216
/// Fan-out stage description that splits each input element into two outputs using a
/// caller-supplied function.
#[derive(Clone)]
pub struct UnzipWith<In, Out0, Out1>
where
    In: 'static,
    Out0: 'static,
    Out1: 'static,
{
    /// Turns one input element into the pair of output elements.
    split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
    /// Ties the stage to its element types without storing values of those types.
    _marker: PhantomData<fn(In) -> (Out0, Out1)>,
}
2222
2223impl<In: 'static, Out0: 'static, Out1: 'static> fmt::Debug for UnzipWith<In, Out0, Out1> {
2224 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2225 f.debug_struct("UnzipWith").finish_non_exhaustive()
2226 }
2227}
2228
2229impl<In: 'static, Out0: 'static, Out1: 'static> UnzipWith<In, Out0, Out1> {
2230 #[must_use]
2231 pub fn new<F>(split: F) -> Self
2232 where
2233 F: Fn(In) -> (Out0, Out1) + Send + Sync + 'static,
2234 {
2235 Self {
2236 split: Arc::new(split),
2237 _marker: PhantomData,
2238 }
2239 }
2240}
2241
2242impl<In, Out0, Out1> GraphStage for UnzipWith<In, Out0, Out1>
2243where
2244 In: Clone + Send + 'static,
2245 Out0: Clone + Send + 'static,
2246 Out1: Clone + Send + 'static,
2247{
2248 type Shape = FanOutShape2<In, Out0, Out1>;
2249
2250 fn name(&self) -> &str {
2251 "UnzipWith"
2252 }
2253
2254 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
2255 FanOutShape2::new(
2256 allocator.inlet("UnzipWith.in"),
2257 allocator.outlet("UnzipWith.out0"),
2258 allocator.outlet("UnzipWith.out1"),
2259 )
2260 }
2261
2262 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2263 let split_fn = Arc::clone(&self.split);
2264 let split = Arc::new(move |dv: DatumValue| -> (DatumValue, DatumValue) {
2265 let value: In = downcast_datum(dv, "unzip_with", || "UnzipWith.in")
2266 .expect("unzip-with: wrong element type");
2267 let (out0, out1) = split_fn(value);
2268 (datum(out0), datum(out1))
2269 });
2270 let typed_split_fn: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::clone(&self.split);
2273 let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
2274 StageSpec::unzip(
2275 Arc::from(self.name()),
2276 shape.inlets(),
2277 shape.outlets(),
2278 split,
2279 typed_split,
2280 )
2281 }
2282
2283 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
2284 struct State {
2285 left_open: bool,
2286 right_open: bool,
2287 upstream_closed: bool,
2288 }
2289
2290 struct InHandlerState<In: 'static, Out0: 'static, Out1: 'static> {
2291 inlet_id: PortId,
2292 inlet: Inlet<In>,
2293 out0: Outlet<Out0>,
2294 out1: Outlet<Out1>,
2295 split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
2296 state: Arc<Mutex<State>>,
2297 }
2298
2299 impl<In, Out0, Out1> InHandler for InHandlerState<In, Out0, Out1>
2300 where
2301 In: Clone + Send + 'static,
2302 Out0: Clone + Send + 'static,
2303 Out1: Clone + Send + 'static,
2304 {
2305 fn on_push(
2306 &mut self,
2307 logic: &mut GraphStageLogic,
2308 _inlet: AnyInlet,
2309 ) -> StreamResult<()> {
2310 let value: In = logic.grab_datum(self.inlet_id).and_then(|value| {
2311 downcast_datum(value, "grab", || {
2312 format!("inlet#{}", self.inlet_id.as_usize())
2313 })
2314 })?;
2315 let (left, right) = (self.split)(value);
2316 let state = self.state.lock().expect("unzip-with state poisoned");
2317 if state.left_open {
2318 logic.push(&self.out0, left)?;
2319 }
2320 if state.right_open {
2321 logic.push(&self.out1, right)?;
2322 }
2323 drop(state);
2324 let state = self.state.lock().expect("unzip-with state poisoned");
2325 let left_ready = !state.left_open || logic.is_available(&self.out0);
2326 let right_ready = !state.right_open || logic.is_available(&self.out1);
2327 if (state.left_open || state.right_open)
2328 && left_ready
2329 && right_ready
2330 && !logic.has_been_pulled(&self.inlet)
2331 {
2332 logic.pull(&self.inlet)?;
2333 }
2334 Ok(())
2335 }
2336
2337 fn on_upstream_finish(
2338 &mut self,
2339 logic: &mut GraphStageLogic,
2340 _inlet: AnyInlet,
2341 ) -> StreamResult<()> {
2342 self.state
2343 .lock()
2344 .expect("unzip-with state poisoned")
2345 .upstream_closed = true;
2346 if !logic.is_closed(&self.out0) {
2347 logic.complete(&self.out0)?;
2348 }
2349 if !logic.is_closed(&self.out1) {
2350 logic.complete(&self.out1)?;
2351 }
2352 Ok(())
2353 }
2354 }
2355
2356 struct Out<In: 'static, Out0: 'static, Out1: 'static> {
2357 is_left: bool,
2358 inlet: Inlet<In>,
2359 out0: Outlet<Out0>,
2360 out1: Outlet<Out1>,
2361 state: Arc<Mutex<State>>,
2362 }
2363
2364 impl<In, Out0, Out1> OutHandler for Out<In, Out0, Out1>
2365 where
2366 In: Clone + Send + 'static,
2367 Out0: Clone + Send + 'static,
2368 Out1: Clone + Send + 'static,
2369 {
2370 fn on_pull(
2371 &mut self,
2372 logic: &mut GraphStageLogic,
2373 _outlet: AnyOutlet,
2374 ) -> StreamResult<()> {
2375 let state = self.state.lock().expect("unzip-with state poisoned");
2376 let left_ready = !state.left_open || logic.is_available(&self.out0);
2377 let right_ready = !state.right_open || logic.is_available(&self.out1);
2378 if state.upstream_closed {
2379 drop(state);
2380 if !logic.is_closed(&self.out0) {
2381 logic.complete(&self.out0)?;
2382 }
2383 if !logic.is_closed(&self.out1) {
2384 logic.complete(&self.out1)?;
2385 }
2386 } else if (state.left_open || state.right_open)
2387 && left_ready
2388 && right_ready
2389 && !logic.has_been_pulled(&self.inlet)
2390 {
2391 drop(state);
2392 logic.pull(&self.inlet)?;
2393 }
2394 Ok(())
2395 }
2396
2397 fn on_downstream_finish(
2398 &mut self,
2399 logic: &mut GraphStageLogic,
2400 _outlet: AnyOutlet,
2401 ) -> StreamResult<()> {
2402 let mut state = self.state.lock().expect("unzip-with state poisoned");
2403 if self.is_left {
2404 state.left_open = false;
2405 } else {
2406 state.right_open = false;
2407 }
2408 if !state.left_open && !state.right_open {
2409 logic.complete_stage()?;
2410 return Ok(());
2411 }
2412 let left_ready = !state.left_open || logic.is_available(&self.out0);
2413 let right_ready = !state.right_open || logic.is_available(&self.out1);
2414 if !state.upstream_closed
2415 && (state.left_open || state.right_open)
2416 && left_ready
2417 && right_ready
2418 && !logic.has_been_pulled(&self.inlet)
2419 {
2420 logic.pull(&self.inlet)?;
2421 }
2422 Ok(())
2423 }
2424 }
2425
2426 let inlet = shape.inlet();
2427 let out0 = shape.out0();
2428 let out1 = shape.out1();
2429 let state = Arc::new(Mutex::new(State {
2430 left_open: true,
2431 right_open: true,
2432 upstream_closed: false,
2433 }));
2434 let mut logic = GraphStageLogic::new(shape);
2435 logic
2436 .set_handler(
2437 &inlet,
2438 Box::new(InHandlerState {
2439 inlet_id: inlet.id(),
2440 inlet: inlet.clone(),
2441 out0: out0.clone(),
2442 out1: out1.clone(),
2443 split: Arc::clone(&self.split),
2444 state: Arc::clone(&state),
2445 }),
2446 )
2447 .unwrap();
2448 logic
2449 .set_out_handler(
2450 &out0,
2451 Box::new(Out {
2452 is_left: true,
2453 inlet: inlet.clone(),
2454 out0: out0.clone(),
2455 out1: out1.clone(),
2456 state: Arc::clone(&state),
2457 }),
2458 )
2459 .unwrap();
2460 logic
2461 .set_out_handler(
2462 &out1.clone(),
2463 Box::new(Out {
2464 is_left: false,
2465 inlet: inlet.clone(),
2466 out0: out0.clone(),
2467 out1: out1.clone(),
2468 state,
2469 }),
2470 )
2471 .unwrap();
2472 logic
2473 }
2474}
2475
/// Flow stage description with one `T` inlet and one `T` outlet, registered as an
/// async-boundary stage spec.
#[derive(Clone, Debug)]
pub struct AsyncBoundary<T>
where
    T: 'static,
{
    /// Ties the stage to `T` without storing a value of that type.
    _marker: PhantomData<fn() -> T>,
}
2480
2481impl<T: 'static> AsyncBoundary<T> {
2482 #[must_use]
2483 pub fn new() -> Self {
2484 Self {
2485 _marker: PhantomData,
2486 }
2487 }
2488}
2489
2490impl<T: 'static> Default for AsyncBoundary<T> {
2491 fn default() -> Self {
2492 Self::new()
2493 }
2494}
2495
2496impl<T> GraphStage for AsyncBoundary<T>
2497where
2498 T: Clone + Send + 'static,
2499{
2500 type Shape = FlowShape<T, T>;
2501
2502 fn name(&self) -> &str {
2503 "AsyncBoundary"
2504 }
2505
2506 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
2507 let first_id = next_port_id_block(2);
2508 FlowShape::new(
2509 Inlet::with_arc_name(first_id, async_boundary_inlet_name()),
2510 Outlet::with_arc_name(first_id.offset(1), async_boundary_outlet_name()),
2511 )
2512 }
2513
2514 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2515 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
2516 }
2517
2518 fn stage_spec_with_ports(
2519 &self,
2520 _shape: &Self::Shape,
2521 inlets: Vec<AnyInlet>,
2522 outlets: Vec<AnyOutlet>,
2523 ) -> StageSpec {
2524 StageSpec::async_boundary(async_boundary_stage_name(), inlets, outlets)
2525 }
2526}