1use super::*;
16
/// Flow stage whose single inlet and single outlet both carry `T`.
///
/// The stage holds no runtime data; `T` is tracked only through a
/// `PhantomData<fn() -> T>` marker. Its spec is built with `StageSpec::identity`.
#[derive(Clone, Debug)]
pub struct Identity<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Identity<T> {
    /// Creates an identity stage for elements of type `T`.
    #[must_use]
    pub fn new() -> Self {
        Identity {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for Identity<T> {
    /// Builds the same value as [`Identity::new`].
    fn default() -> Self {
        Identity {
            _marker: PhantomData,
        }
    }
}
37
38impl<T> GraphStage for Identity<T>
39where
40 T: Clone + Send + 'static,
41{
42 type Shape = FlowShape<T, T>;
43
44 fn name(&self) -> &str {
45 "Identity"
46 }
47
48 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
49 let first_id = next_port_id_block(2);
50 FlowShape::new(
51 Inlet::with_arc_name(first_id, identity_inlet_name()),
52 Outlet::with_arc_name(first_id.offset(1), identity_outlet_name()),
53 )
54 }
55
56 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
57 StageSpec::identity(shape.inlets(), shape.outlets())
58 }
59
60 fn stage_spec_with_ports(
61 &self,
62 _shape: &Self::Shape,
63 inlets: Vec<AnyInlet>,
64 outlets: Vec<AnyOutlet>,
65 ) -> StageSpec {
66 StageSpec::identity(inlets, outlets)
67 }
68}
69
/// Flow stage that turns each `In` element into an `Out` element with a
/// user-supplied function.
///
/// The function is stored behind an `Arc`, so cloning the stage shares the
/// same closure instead of copying it.
#[derive(Clone)]
pub struct MapStage<In: 'static, Out: 'static> {
    f: Arc<dyn Fn(In) -> Out + Send + Sync>,
    _marker: PhantomData<fn(In) -> Out>,
}

impl<In: 'static, Out: 'static> fmt::Debug for MapStage<In, Out> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The closure has no `Debug` impl, so only the stage name is reported.
        let mut builder = f.debug_struct("MapStage");
        builder.field("name", &"Map");
        builder.finish_non_exhaustive()
    }
}

impl<In: 'static, Out: 'static> MapStage<In, Out> {
    /// Wraps `f` as a map stage.
    #[must_use]
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(In) -> Out + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(In) -> Out + Send + Sync> = Arc::new(f);
        MapStage {
            f: shared,
            _marker: PhantomData,
        }
    }
}
97
98impl<In, Out> GraphStage for MapStage<In, Out>
99where
100 In: Clone + Send + 'static,
101 Out: Clone + Send + 'static,
102{
103 type Shape = FlowShape<In, Out>;
104
105 fn name(&self) -> &str {
106 "Map"
107 }
108
109 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
110 let first_id = next_port_id_block(2);
111 FlowShape::new(
112 Inlet::with_arc_name(first_id, map_inlet_name()),
113 Outlet::with_arc_name(first_id.offset(1), map_outlet_name()),
114 )
115 }
116
117 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
118 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
119 }
120
121 fn stage_spec_with_ports(
122 &self,
123 _shape: &Self::Shape,
124 inlets: Vec<AnyInlet>,
125 outlets: Vec<AnyOutlet>,
126 ) -> StageSpec {
127 let f = Arc::clone(&self.f);
128 let typed = Arc::new(Arc::clone(&self.f)) as Arc<StageTypedMapFn>;
129 let mapper = Arc::new(move |value: DatumValue| {
130 let value: In = downcast_datum(value, "map", || "Map.in")?;
131 Ok(datum(f(value)))
132 });
133 StageSpec::map(map_stage_name(), inlets, outlets, mapper, typed)
134 }
135}
136
137#[derive(Clone, Debug)]
142pub struct Buffer<T: 'static> {
143 capacity: usize,
144 strategy: OverflowStrategy,
145 _marker: PhantomData<fn() -> T>,
146}
147
148impl<T: 'static> Buffer<T> {
149 #[must_use]
150 pub fn new(capacity: usize, strategy: OverflowStrategy) -> Self {
151 assert!(capacity > 0, "buffer capacity must be greater than zero");
152 Self {
153 capacity,
154 strategy,
155 _marker: PhantomData,
156 }
157 }
158}
159
160impl<T> GraphStage for Buffer<T>
161where
162 T: Clone + Send + 'static,
163{
164 type Shape = FlowShape<T, T>;
165
166 fn name(&self) -> &str {
167 "Buffer"
168 }
169
170 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
171 let first_id = next_port_id_block(2);
172 FlowShape::new(
173 Inlet::with_id(first_id, "Buffer.in"),
174 Outlet::with_id(first_id.offset(1), "Buffer.out"),
175 )
176 }
177
178 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
179 StageSpec::opaque(self.name(), shape.inlets(), shape.outlets())
183 .with_typed_cyclic(TypedCyclicOp::BufferPassthrough)
184 }
185
186 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
187 struct State<T> {
188 queue: VecDeque<T>,
189 upstream_closed: bool,
190 }
191
192 fn lock_state<T>(
193 state: &Arc<Mutex<State<T>>>,
194 ) -> StreamResult<std::sync::MutexGuard<'_, State<T>>> {
195 state
196 .lock()
197 .map_err(|_| StreamError::Failed("graph buffer state poisoned".into()))
198 }
199
200 fn enqueue<T>(
201 state: &mut State<T>,
202 value: T,
203 capacity: usize,
204 strategy: OverflowStrategy,
205 ) -> StreamResult<()> {
206 if state.queue.len() < capacity {
207 state.queue.push_back(value);
208 return Ok(());
209 }
210
211 match strategy {
212 OverflowStrategy::Backpressure | OverflowStrategy::Fail => Err(
213 StreamError::Failed(format!("Buffer overflow (max capacity was: {capacity})!")),
214 ),
215 OverflowStrategy::DropHead => {
216 let _ = state.queue.pop_front();
217 state.queue.push_back(value);
218 Ok(())
219 }
220 OverflowStrategy::DropTail => {
221 let _ = state.queue.pop_back();
222 state.queue.push_back(value);
223 Ok(())
224 }
225 OverflowStrategy::DropBuffer => {
226 state.queue.clear();
227 state.queue.push_back(value);
228 Ok(())
229 }
230 OverflowStrategy::DropNew => Ok(()),
231 }
232 }
233
234 fn drain_one<T>(
235 logic: &mut GraphStageLogic,
236 outlet: &Outlet<T>,
237 state: &Arc<Mutex<State<T>>>,
238 ) -> StreamResult<()>
239 where
240 T: Clone + Send + 'static,
241 {
242 if !logic.is_available(outlet) {
243 return Ok(());
244 }
245 let next = lock_state(state)?.queue.pop_front();
246 if let Some(value) = next {
247 logic.push(outlet, value)?;
248 }
249 Ok(())
250 }
251
252 fn maybe_pull<T>(
253 logic: &mut GraphStageLogic,
254 inlet: &Inlet<T>,
255 state: &Arc<Mutex<State<T>>>,
256 capacity: usize,
257 ) -> StreamResult<()>
258 where
259 T: 'static,
260 {
261 let can_pull = {
262 let state = lock_state(state)?;
263 !state.upstream_closed && state.queue.len() < capacity
264 };
265 if can_pull && !logic.has_been_pulled(inlet) {
266 logic.pull(inlet)?;
267 }
268 Ok(())
269 }
270
271 fn maybe_complete<T>(
272 logic: &mut GraphStageLogic,
273 outlet: &Outlet<T>,
274 state: &Arc<Mutex<State<T>>>,
275 ) -> StreamResult<()>
276 where
277 T: 'static,
278 {
279 let done = {
280 let state = lock_state(state)?;
281 state.upstream_closed && state.queue.is_empty()
282 };
283 if done && !logic.is_closed(outlet) {
284 logic.complete(outlet)?;
285 }
286 Ok(())
287 }
288
289 struct In<T: 'static> {
290 inlet: Inlet<T>,
291 outlet: Outlet<T>,
292 state: Arc<Mutex<State<T>>>,
293 capacity: usize,
294 strategy: OverflowStrategy,
295 }
296
297 impl<T> InHandler for In<T>
298 where
299 T: Clone + Send + 'static,
300 {
301 fn on_push(
302 &mut self,
303 logic: &mut GraphStageLogic,
304 _inlet: AnyInlet,
305 ) -> StreamResult<()> {
306 let value = logic.grab(&self.inlet)?;
307 {
308 let mut state = lock_state(&self.state)?;
309 enqueue(&mut state, value, self.capacity, self.strategy)?;
310 }
311 drain_one(logic, &self.outlet, &self.state)?;
312 maybe_pull(logic, &self.inlet, &self.state, self.capacity)?;
313 maybe_complete(logic, &self.outlet, &self.state)
314 }
315
316 fn on_upstream_finish(
317 &mut self,
318 logic: &mut GraphStageLogic,
319 _inlet: AnyInlet,
320 ) -> StreamResult<()> {
321 lock_state(&self.state)?.upstream_closed = true;
322 drain_one(logic, &self.outlet, &self.state)?;
323 maybe_complete(logic, &self.outlet, &self.state)
324 }
325 }
326
327 struct Out<T: 'static> {
328 inlet: Inlet<T>,
329 outlet: Outlet<T>,
330 state: Arc<Mutex<State<T>>>,
331 capacity: usize,
332 }
333
334 impl<T> OutHandler for Out<T>
335 where
336 T: Clone + Send + 'static,
337 {
338 fn on_pull(
339 &mut self,
340 logic: &mut GraphStageLogic,
341 _outlet: AnyOutlet,
342 ) -> StreamResult<()> {
343 drain_one(logic, &self.outlet, &self.state)?;
344 maybe_pull(logic, &self.inlet, &self.state, self.capacity)?;
345 maybe_complete(logic, &self.outlet, &self.state)
346 }
347
348 fn on_downstream_finish(
349 &mut self,
350 logic: &mut GraphStageLogic,
351 _outlet: AnyOutlet,
352 ) -> StreamResult<()> {
353 logic.complete_stage()
354 }
355 }
356
357 let state = Arc::new(Mutex::new(State {
358 queue: VecDeque::new(),
359 upstream_closed: false,
360 }));
361 let mut logic = GraphStageLogic::new(shape);
362 logic.pull(&shape.inlet()).unwrap();
363 logic
364 .set_handler(
365 &shape.inlet(),
366 Box::new(In {
367 inlet: shape.inlet(),
368 outlet: shape.outlet(),
369 state: Arc::clone(&state),
370 capacity: self.capacity,
371 strategy: self.strategy,
372 }),
373 )
374 .unwrap();
375 logic
376 .set_out_handler(
377 &shape.outlet(),
378 Box::new(Out {
379 inlet: shape.inlet(),
380 outlet: shape.outlet(),
381 state,
382 capacity: self.capacity,
383 }),
384 )
385 .unwrap();
386 logic
387 }
388}
389
/// Flow stage that forwards elements while `predicate` returns `true`.
///
/// The first element for which the predicate returns `false` is not
/// forwarded; it completes the stage's outlet instead.
#[derive(Clone)]
pub struct TakeWhile<T: 'static> {
    predicate: Arc<dyn Fn(&T) -> bool + Send + Sync>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for TakeWhile<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The predicate closure has no `Debug` impl; only the stage name is shown.
        let mut out = f.debug_struct("TakeWhile");
        out.field("name", &"TakeWhile");
        out.finish_non_exhaustive()
    }
}

impl<T: 'static> TakeWhile<T> {
    /// Creates a take-while stage driven by `predicate`.
    #[must_use]
    pub fn new<F>(predicate: F) -> Self
    where
        F: Fn(&T) -> bool + Send + Sync + 'static,
    {
        let predicate: Arc<dyn Fn(&T) -> bool + Send + Sync> = Arc::new(predicate);
        TakeWhile {
            predicate,
            _marker: PhantomData,
        }
    }
}
417
418impl<T> GraphStage for TakeWhile<T>
419where
420 T: Clone + Send + 'static,
421{
422 type Shape = FlowShape<T, T>;
423
424 fn name(&self) -> &str {
425 "TakeWhile"
426 }
427
428 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
429 let first_id = next_port_id_block(2);
430 FlowShape::new(
431 Inlet::with_id(first_id, "TakeWhile.in"),
432 Outlet::with_id(first_id.offset(1), "TakeWhile.out"),
433 )
434 }
435
436 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
437 let predicate: Arc<dyn Fn(&T) -> bool + Send + Sync> = Arc::clone(&self.predicate);
442 StageSpec::opaque(self.name(), shape.inlets(), shape.outlets()).with_typed_cyclic(
443 TypedCyclicOp::TakeWhile(Arc::new(predicate) as Arc<dyn Any + Send + Sync>),
444 )
445 }
446
447 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
448 struct In<T: 'static> {
449 inlet: Inlet<T>,
450 outlet: Outlet<T>,
451 predicate: Arc<dyn Fn(&T) -> bool + Send + Sync>,
452 }
453
454 impl<T> InHandler for In<T>
455 where
456 T: Clone + Send + 'static,
457 {
458 fn on_push(
459 &mut self,
460 logic: &mut GraphStageLogic,
461 _inlet: AnyInlet,
462 ) -> StreamResult<()> {
463 let value = logic.grab(&self.inlet)?;
464 if (self.predicate)(&value) {
465 logic.emit(&self.outlet, value)
466 } else if logic.is_closed(&self.outlet) {
467 Ok(())
468 } else {
469 logic.complete(&self.outlet)
470 }
471 }
472
473 fn on_upstream_finish(
474 &mut self,
475 logic: &mut GraphStageLogic,
476 _inlet: AnyInlet,
477 ) -> StreamResult<()> {
478 if logic.is_closed(&self.outlet) {
479 Ok(())
480 } else {
481 logic.complete(&self.outlet)
482 }
483 }
484 }
485
486 struct Out<T: 'static> {
487 inlet: Inlet<T>,
488 }
489
490 impl<T> OutHandler for Out<T>
491 where
492 T: Clone + Send + 'static,
493 {
494 fn on_pull(
495 &mut self,
496 logic: &mut GraphStageLogic,
497 _outlet: AnyOutlet,
498 ) -> StreamResult<()> {
499 if !logic.has_been_pulled(&self.inlet) {
500 logic.pull(&self.inlet)?;
501 }
502 Ok(())
503 }
504
505 fn on_downstream_finish(
506 &mut self,
507 logic: &mut GraphStageLogic,
508 _outlet: AnyOutlet,
509 ) -> StreamResult<()> {
510 logic.complete_stage()
511 }
512 }
513
514 let mut logic = GraphStageLogic::new(shape);
515 logic
516 .set_handler(
517 &shape.inlet(),
518 Box::new(In {
519 inlet: shape.inlet(),
520 outlet: shape.outlet(),
521 predicate: Arc::clone(&self.predicate),
522 }),
523 )
524 .unwrap();
525 logic
526 .set_out_handler(
527 &shape.outlet(),
528 Box::new(Out {
529 inlet: shape.inlet(),
530 }),
531 )
532 .unwrap();
533 logic
534 }
535}
536
/// Fan-out stage with one inlet and `outputs` outlets, described by a
/// `StageSpec::broadcast` spec.
#[derive(Clone, Debug)]
pub struct Broadcast<T: 'static> {
    outputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Broadcast<T> {
    /// Creates a broadcast stage with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        assert!(
            outputs != 0,
            "broadcast output count must be greater than zero"
        );
        Broadcast {
            outputs,
            _marker: PhantomData,
        }
    }
}
557
558impl<T> GraphStage for Broadcast<T>
559where
560 T: Clone + Send + 'static,
561{
562 type Shape = FanOutShape<T, T>;
563
564 fn name(&self) -> &str {
565 "Broadcast"
566 }
567
568 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
569 let inlet = allocator.inlet_arc(broadcast_inlet_name());
570 let outlets = (0..self.outputs)
571 .map(|index| allocator.outlet(format!("Broadcast.out{index}")))
572 .collect();
573 FanOutShape::new(inlet, outlets)
574 }
575
576 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
577 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
578 }
579
580 fn stage_spec_with_ports(
581 &self,
582 _shape: &Self::Shape,
583 inlets: Vec<AnyInlet>,
584 outlets: Vec<AnyOutlet>,
585 ) -> StageSpec {
586 StageSpec::broadcast(broadcast_stage_name(), inlets, outlets)
587 }
588}
589
/// Fan-out stage with one inlet and `outputs` outlets, described by a
/// `StageSpec::balance` spec.
#[derive(Clone, Debug)]
pub struct Balance<T: 'static> {
    outputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Balance<T> {
    /// Creates a balance stage with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        assert!(
            outputs != 0,
            "balance output count must be greater than zero"
        );
        Balance {
            outputs,
            _marker: PhantomData,
        }
    }
}
611
612impl<T> GraphStage for Balance<T>
613where
614 T: Clone + Send + 'static,
615{
616 type Shape = FanOutShape<T, T>;
617
618 fn name(&self) -> &str {
619 "Balance"
620 }
621
622 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
623 let inlet = allocator.inlet_arc(balance_inlet_name());
624 let outlets = (0..self.outputs)
625 .map(|index| allocator.outlet(format!("Balance.out{index}")))
626 .collect();
627 FanOutShape::new(inlet, outlets)
628 }
629
630 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
631 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
632 }
633
634 fn stage_spec_with_ports(
635 &self,
636 _shape: &Self::Shape,
637 inlets: Vec<AnyInlet>,
638 outlets: Vec<AnyOutlet>,
639 ) -> StageSpec {
640 StageSpec::balance(balance_stage_name(), inlets, outlets)
641 }
642}
643
/// Fan-in stage with `inputs` inlets and one outlet, described by a
/// `StageSpec::merge` spec.
#[derive(Clone, Debug)]
pub struct Merge<T: 'static> {
    inputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Merge<T> {
    /// Creates a merge stage with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs != 0, "merge input count must be greater than zero");
        Merge {
            inputs,
            _marker: PhantomData,
        }
    }
}
662
663impl<T> GraphStage for Merge<T>
664where
665 T: Clone + Send + 'static,
666{
667 type Shape = FanInShape<T, T>;
668
669 fn name(&self) -> &str {
670 "Merge"
671 }
672
673 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
674 let inlets = (0..self.inputs)
675 .map(|index| allocator.inlet(format!("Merge.in{index}")))
676 .collect();
677 FanInShape::new(inlets, allocator.outlet_arc(merge_outlet_name()))
678 }
679
680 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
681 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
682 }
683
684 fn stage_spec_with_ports(
685 &self,
686 _shape: &Self::Shape,
687 inlets: Vec<AnyInlet>,
688 outlets: Vec<AnyOutlet>,
689 ) -> StageSpec {
690 StageSpec::merge(merge_stage_name(), inlets, outlets)
691 }
692}
693
/// Fan-in stage with `inputs` inlets and one outlet, described by a
/// `StageSpec::concat` spec.
#[derive(Clone, Debug)]
pub struct Concat<T: 'static> {
    inputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Concat<T> {
    /// Creates a concat stage with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs >= 2, "concat input count must be greater than one");
        Concat {
            inputs,
            _marker: PhantomData,
        }
    }
}
712
713impl<T> GraphStage for Concat<T>
714where
715 T: Clone + Send + 'static,
716{
717 type Shape = FanInShape<T, T>;
718
719 fn name(&self) -> &str {
720 "Concat"
721 }
722
723 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
724 let inlets = (0..self.inputs)
725 .map(|index| allocator.inlet(format!("Concat.in{index}")))
726 .collect();
727 FanInShape::new(inlets, allocator.outlet_arc(concat_outlet_name()))
728 }
729
730 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
731 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
732 }
733
734 fn stage_spec_with_ports(
735 &self,
736 _shape: &Self::Shape,
737 inlets: Vec<AnyInlet>,
738 outlets: Vec<AnyOutlet>,
739 ) -> StageSpec {
740 StageSpec::concat(concat_stage_name(), inlets, outlets)
741 }
742}
743
/// Fan-in stage with a primary inlet, a secondary inlet and one outlet,
/// described by a `StageSpec::or_else` spec.
#[derive(Clone, Debug)]
pub struct OrElse<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> OrElse<T> {
    /// Creates an or-else stage for elements of type `T`.
    #[must_use]
    pub fn new() -> Self {
        OrElse {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for OrElse<T> {
    /// Builds the same value as [`OrElse::new`].
    fn default() -> Self {
        OrElse {
            _marker: PhantomData,
        }
    }
}
765
766impl<T> GraphStage for OrElse<T>
767where
768 T: Clone + Send + 'static,
769{
770 type Shape = FanInShape<T, T>;
771
772 fn name(&self) -> &str {
773 "OrElse"
774 }
775
776 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
777 let inlets = vec![
778 allocator.inlet_arc(or_else_primary_name()),
779 allocator.inlet_arc(or_else_secondary_name()),
780 ];
781 FanInShape::new(inlets, allocator.outlet_arc(or_else_outlet_name()))
782 }
783
784 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
785 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
786 }
787
788 fn stage_spec_with_ports(
789 &self,
790 _shape: &Self::Shape,
791 inlets: Vec<AnyInlet>,
792 outlets: Vec<AnyOutlet>,
793 ) -> StageSpec {
794 StageSpec::or_else(or_else_stage_name(), inlets, outlets)
795 }
796}
797
/// Fan-in stage with `inputs` inlets and one outlet.
///
/// `segment_size` and `eager_close` are forwarded to `StageSpec::interleave`.
#[derive(Clone, Debug)]
pub struct Interleave<T: 'static> {
    inputs: usize,
    segment_size: usize,
    eager_close: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Interleave<T> {
    /// Creates an interleave stage with `eager_close` disabled.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Interleave::new_with_eager_close`].
    #[must_use]
    pub fn new(inputs: usize, segment_size: usize) -> Self {
        Self::new_with_eager_close(inputs, segment_size, false)
    }

    /// Creates an interleave stage with an explicit `eager_close` flag.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two or `segment_size` is zero.
    #[must_use]
    pub fn new_with_eager_close(inputs: usize, segment_size: usize, eager_close: bool) -> Self {
        assert!(
            inputs >= 2,
            "interleave input count must be greater than one"
        );
        assert!(
            segment_size != 0,
            "interleave segment size must be greater than zero"
        );
        Interleave {
            inputs,
            segment_size,
            eager_close,
            _marker: PhantomData,
        }
    }
}
832
833impl<T> GraphStage for Interleave<T>
834where
835 T: Clone + Send + 'static,
836{
837 type Shape = FanInShape<T, T>;
838
839 fn name(&self) -> &str {
840 "Interleave"
841 }
842
843 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
844 let inlets = (0..self.inputs)
845 .map(|index| allocator.inlet(format!("Interleave.in{index}")))
846 .collect();
847 FanInShape::new(inlets, allocator.outlet_arc(interleave_outlet_name()))
848 }
849
850 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
851 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
852 }
853
854 fn stage_spec_with_ports(
855 &self,
856 _shape: &Self::Shape,
857 inlets: Vec<AnyInlet>,
858 outlets: Vec<AnyOutlet>,
859 ) -> StageSpec {
860 StageSpec::interleave(
861 interleave_stage_name(),
862 inlets,
863 outlets,
864 self.segment_size,
865 self.eager_close,
866 )
867 }
868}
869
/// Fan-in stage with one preferred inlet, `secondary_ports` secondary inlets
/// and one outlet, described by a `StageSpec::merge_preferred` spec.
#[derive(Clone, Debug)]
pub struct MergePreferred<T: 'static> {
    secondary_ports: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePreferred<T> {
    /// Creates a merge-preferred stage with `secondary_ports` secondary inlets.
    ///
    /// # Panics
    ///
    /// Panics if `secondary_ports` is zero.
    #[must_use]
    pub fn new(secondary_ports: usize) -> Self {
        assert!(
            secondary_ports != 0,
            "merge-preferred secondary input count must be greater than zero"
        );
        MergePreferred {
            secondary_ports,
            _marker: PhantomData,
        }
    }
}
891
892impl<T> GraphStage for MergePreferred<T>
893where
894 T: Clone + Send + 'static,
895{
896 type Shape = MergePreferredShape<T>;
897
898 fn name(&self) -> &str {
899 "MergePreferred"
900 }
901
902 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
903 let preferred = allocator.inlet_arc(merge_preferred_preferred_name());
904 let secondary = (0..self.secondary_ports)
905 .map(|index| allocator.inlet(format!("MergePreferred.in{index}")))
906 .collect();
907 MergePreferredShape::new(
908 preferred,
909 secondary,
910 allocator.outlet_arc(merge_preferred_outlet_name()),
911 )
912 }
913
914 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
915 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
916 }
917
918 fn stage_spec_with_ports(
919 &self,
920 _shape: &Self::Shape,
921 inlets: Vec<AnyInlet>,
922 outlets: Vec<AnyOutlet>,
923 ) -> StageSpec {
924 StageSpec::merge_preferred(merge_preferred_stage_name(), inlets, outlets)
925 }
926}
927
/// Fan-in stage with one inlet per entry of `weights` and one outlet.
///
/// The weights are forwarded to `StageSpec::merge_prioritized`.
#[derive(Clone, Debug)]
pub struct MergePrioritized<T: 'static> {
    weights: Vec<usize>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePrioritized<T> {
    /// Creates a prioritized merge with one inlet per weight.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty or contains a zero.
    #[must_use]
    pub fn new(weights: Vec<usize>) -> Self {
        assert!(!weights.is_empty(), "prioritized merge must have inputs");
        assert!(
            !weights.contains(&0),
            "prioritized merge weights must be greater than zero"
        );
        MergePrioritized {
            weights,
            _marker: PhantomData,
        }
    }
}
951
952impl<T> GraphStage for MergePrioritized<T>
953where
954 T: Clone + Send + 'static,
955{
956 type Shape = FanInShape<T, T>;
957
958 fn name(&self) -> &str {
959 "MergePrioritized"
960 }
961
962 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
963 let inlets = (0..self.weights.len())
964 .map(|index| allocator.inlet(format!("MergePrioritized.in{index}")))
965 .collect();
966 FanInShape::new(
967 inlets,
968 allocator.outlet_arc(merge_prioritized_outlet_name()),
969 )
970 }
971
972 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
973 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
974 }
975
976 fn stage_spec_with_ports(
977 &self,
978 _shape: &Self::Shape,
979 inlets: Vec<AnyInlet>,
980 outlets: Vec<AnyOutlet>,
981 ) -> StageSpec {
982 StageSpec::merge_prioritized(
983 merge_prioritized_stage_name(),
984 inlets,
985 outlets,
986 self.weights.clone(),
987 )
988 }
989}
990
/// Fan-in stage with two inlets (carrying `Left` and `Right`) and one outlet
/// carrying `(Left, Right)` pairs.
#[derive(Clone, Debug)]
pub struct Zip<Left: 'static, Right: 'static> {
    _marker: PhantomData<fn() -> (Left, Right)>,
}

impl<Left: 'static, Right: 'static> Zip<Left, Right> {
    /// Creates a zip stage for `Left` and `Right` elements.
    #[must_use]
    pub fn new() -> Self {
        Zip {
            _marker: PhantomData,
        }
    }
}

impl<Left: 'static, Right: 'static> Default for Zip<Left, Right> {
    /// Builds the same value as [`Zip::new`].
    fn default() -> Self {
        Zip {
            _marker: PhantomData,
        }
    }
}
1012
1013impl<Left, Right> GraphStage for Zip<Left, Right>
1014where
1015 Left: Clone + Send + 'static,
1016 Right: Clone + Send + 'static,
1017{
1018 type Shape = ZipShape<Left, Right>;
1019
1020 fn name(&self) -> &str {
1021 "Zip"
1022 }
1023
1024 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
1025 let first_id = next_port_id_block(3);
1026 ZipShape::new(
1027 Inlet::with_arc_name(first_id, zip_in0_name()),
1028 Inlet::with_arc_name(first_id.offset(1), zip_in1_name()),
1029 Outlet::with_arc_name(first_id.offset(2), zip_outlet_name()),
1030 )
1031 }
1032
1033 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1034 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
1035 }
1036
1037 fn stage_spec_with_ports(
1038 &self,
1039 _shape: &Self::Shape,
1040 inlets: Vec<AnyInlet>,
1041 outlets: Vec<AnyOutlet>,
1042 ) -> StageSpec {
1043 let zip = Arc::new(move |left: DatumValue, right: DatumValue| {
1044 let left: Left = downcast_datum(left, "zip", || "Zip.in0")?;
1045 let right: Right = downcast_datum(right, "zip", || "Zip.in1")?;
1046 Ok(datum((left, right)))
1047 });
1048 StageSpec::zip(zip_stage_name(), inlets, outlets, zip)
1049 }
1050}
1051
/// Fan-in stage with two inlets and one outlet that merges by `Ord` order.
///
/// Its `GraphStage` impl requires `T: Ord`.
#[derive(Clone, Debug)]
pub struct MergeSorted<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeSorted<T> {
    /// Creates a merge-sorted stage for elements of type `T`.
    #[must_use]
    pub fn new() -> Self {
        MergeSorted {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for MergeSorted<T> {
    /// Builds the same value as [`MergeSorted::new`].
    fn default() -> Self {
        MergeSorted {
            _marker: PhantomData,
        }
    }
}
1073
1074impl<T> GraphStage for MergeSorted<T>
1075where
1076 T: Clone + Ord + Send + 'static,
1077{
1078 type Shape = FanInShape<T, T>;
1079
1080 fn name(&self) -> &str {
1081 "MergeSorted"
1082 }
1083
1084 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1085 let inlets = vec![
1086 allocator.inlet("MergeSorted.in0"),
1087 allocator.inlet("MergeSorted.in1"),
1088 ];
1089 FanInShape::new(inlets, allocator.outlet("MergeSorted.out"))
1090 }
1091
1092 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1093 let compare = Arc::new(
1094 move |a: &DatumValue, b: &DatumValue| -> std::cmp::Ordering {
1095 let a_t: &T = a
1096 .as_any_ref()
1097 .downcast_ref::<T>()
1098 .expect("merge-sorted compare: wrong element type");
1099 let b_t: &T = b
1100 .as_any_ref()
1101 .downcast_ref::<T>()
1102 .expect("merge-sorted compare: wrong element type");
1103 a_t.cmp(b_t)
1104 },
1105 );
1106 StageSpec::merge_sorted(
1107 Arc::from(self.name()),
1108 shape.inlets(),
1109 shape.outlets(),
1110 compare,
1111 )
1112 }
1113
1114 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1115 struct State<T> {
1116 left: VecDeque<T>,
1117 right: VecDeque<T>,
1118 left_closed: bool,
1119 right_closed: bool,
1120 pending: VecDeque<T>,
1121 }
1122
1123 impl<T> Default for State<T> {
1124 fn default() -> Self {
1125 Self {
1126 left: VecDeque::new(),
1127 right: VecDeque::new(),
1128 left_closed: false,
1129 right_closed: false,
1130 pending: VecDeque::new(),
1131 }
1132 }
1133 }
1134
1135 fn maybe_queue_output<T>(state: &mut State<T>) -> bool
1136 where
1137 T: Clone + Ord,
1138 {
1139 let next = match (state.left.front(), state.right.front()) {
1140 (Some(left), Some(right)) => {
1141 if left <= right {
1142 state.left.pop_front()
1143 } else {
1144 state.right.pop_front()
1145 }
1146 }
1147 (Some(_), None) if state.right_closed => state.left.pop_front(),
1148 (None, Some(_)) if state.left_closed => state.right.pop_front(),
1149 _ => None,
1150 };
1151 if let Some(value) = next {
1152 state.pending.push_back(value);
1153 true
1154 } else {
1155 false
1156 }
1157 }
1158
1159 fn maybe_complete<T>(
1160 logic: &mut GraphStageLogic,
1161 outlet: &Outlet<T>,
1162 state: &State<T>,
1163 ) -> StreamResult<()>
1164 where
1165 T: Clone + Send + 'static,
1166 {
1167 if state.left_closed
1168 && state.right_closed
1169 && state.left.is_empty()
1170 && state.right.is_empty()
1171 && state.pending.is_empty()
1172 && !logic.is_closed(outlet)
1173 {
1174 logic.complete(outlet)?;
1175 }
1176 Ok(())
1177 }
1178
1179 fn maybe_pull<T>(
1180 logic: &mut GraphStageLogic,
1181 left: &Inlet<T>,
1182 right: &Inlet<T>,
1183 state: &State<T>,
1184 ) -> StreamResult<()>
1185 where
1186 T: Clone + Send + 'static,
1187 {
1188 if state.left.is_empty() && !state.left_closed && !logic.has_been_pulled(left) {
1189 logic.pull(left)?;
1190 }
1191 if state.right.is_empty() && !state.right_closed && !logic.has_been_pulled(right) {
1192 logic.pull(right)?;
1193 }
1194 Ok(())
1195 }
1196
1197 fn maybe_drain<T>(
1198 logic: &mut GraphStageLogic,
1199 outlet: &Outlet<T>,
1200 state: &Arc<Mutex<State<T>>>,
1201 ) -> StreamResult<()>
1202 where
1203 T: Clone + Ord + Send + 'static,
1204 {
1205 let next = if logic.is_available(outlet) {
1206 state
1207 .lock()
1208 .expect("merge-sorted state poisoned")
1209 .pending
1210 .pop_front()
1211 } else {
1212 None
1213 };
1214 if let Some(value) = next {
1215 logic.push(outlet, value)?;
1216 }
1217 Ok(())
1218 }
1219
1220 struct In<T: 'static> {
1221 inlet_id: PortId,
1222 left: Inlet<T>,
1223 right: Inlet<T>,
1224 outlet: Outlet<T>,
1225 state: Arc<Mutex<State<T>>>,
1226 }
1227
1228 impl<T> InHandler for In<T>
1229 where
1230 T: Clone + Ord + Send + 'static,
1231 {
1232 fn on_push(
1233 &mut self,
1234 logic: &mut GraphStageLogic,
1235 _inlet: AnyInlet,
1236 ) -> StreamResult<()> {
1237 let value: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1238 downcast_datum(value, "grab", || {
1239 format!("inlet#{}", self.inlet_id.as_usize())
1240 })
1241 })?;
1242 {
1243 let mut state = self.state.lock().expect("merge-sorted state poisoned");
1244 if self.inlet_id == self.left.id() {
1245 state.left.push_back(value);
1246 } else {
1247 state.right.push_back(value);
1248 }
1249 while maybe_queue_output(&mut state) {}
1250 }
1251 maybe_drain(logic, &self.outlet, &self.state)?;
1252 let state = self.state.lock().expect("merge-sorted state poisoned");
1253 maybe_complete(logic, &self.outlet, &state)?;
1254 maybe_pull(logic, &self.left, &self.right, &state)
1255 }
1256
1257 fn on_upstream_finish(
1258 &mut self,
1259 logic: &mut GraphStageLogic,
1260 _inlet: AnyInlet,
1261 ) -> StreamResult<()> {
1262 {
1263 let mut state = self.state.lock().expect("merge-sorted state poisoned");
1264 if self.inlet_id == self.left.id() {
1265 state.left_closed = true;
1266 } else {
1267 state.right_closed = true;
1268 }
1269 while maybe_queue_output(&mut state) {}
1270 }
1271 maybe_drain(logic, &self.outlet, &self.state)?;
1272 let state = self.state.lock().expect("merge-sorted state poisoned");
1273 maybe_complete(logic, &self.outlet, &state)?;
1274 maybe_pull(logic, &self.left, &self.right, &state)
1275 }
1276 }
1277
1278 struct Out<T: 'static> {
1279 left: Inlet<T>,
1280 right: Inlet<T>,
1281 outlet: Outlet<T>,
1282 state: Arc<Mutex<State<T>>>,
1283 }
1284
1285 impl<T> OutHandler for Out<T>
1286 where
1287 T: Clone + Ord + Send + 'static,
1288 {
1289 fn on_pull(
1290 &mut self,
1291 logic: &mut GraphStageLogic,
1292 _outlet: AnyOutlet,
1293 ) -> StreamResult<()> {
1294 maybe_drain(logic, &self.outlet, &self.state)?;
1295 let state = self.state.lock().expect("merge-sorted state poisoned");
1296 maybe_complete(logic, &self.outlet, &state)?;
1297 maybe_pull(logic, &self.left, &self.right, &state)
1298 }
1299 }
1300
1301 let state = Arc::new(Mutex::new(State::<T>::default()));
1302 let left = shape.inlet(0).expect("merge-sorted left inlet");
1303 let right = shape.inlet(1).expect("merge-sorted right inlet");
1304 let outlet = shape.outlet();
1305 let mut logic = GraphStageLogic::new(shape);
1306 logic
1307 .set_handler(
1308 &left,
1309 Box::new(In {
1310 inlet_id: left.id(),
1311 left: left.clone(),
1312 right: right.clone(),
1313 outlet: outlet.clone(),
1314 state: Arc::clone(&state),
1315 }),
1316 )
1317 .unwrap();
1318 logic
1319 .set_handler(
1320 &right,
1321 Box::new(In {
1322 inlet_id: right.id(),
1323 left: left.clone(),
1324 right: right.clone(),
1325 outlet: outlet.clone(),
1326 state: Arc::clone(&state),
1327 }),
1328 )
1329 .unwrap();
1330 logic
1331 .set_out_handler(
1332 &outlet.clone(),
1333 Box::new(Out {
1334 left,
1335 right,
1336 outlet: outlet.clone(),
1337 state,
1338 }),
1339 )
1340 .unwrap();
1341 logic
1342 }
1343}
1344
/// Fan-in stage over `inputs` inlets.
///
/// It uses `extract_sequence` to read a `u64` sequence number from every
/// element.
#[derive(Clone)]
pub struct MergeSequence<T: 'static> {
    inputs: usize,
    extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for MergeSequence<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The extractor closure has no `Debug` impl; only the input count is shown.
        let mut out = f.debug_struct("MergeSequence");
        out.field("inputs", &self.inputs);
        out.finish_non_exhaustive()
    }
}

impl<T: 'static> MergeSequence<T> {
    /// Creates a merge-sequence stage with `inputs` inlets and the given
    /// sequence-number extractor.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new<F>(inputs: usize, extract_sequence: F) -> Self
    where
        F: Fn(&T) -> u64 + Send + Sync + 'static,
    {
        assert!(
            inputs >= 2,
            "merge sequence input count must be greater than one"
        );
        let extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync> = Arc::new(extract_sequence);
        MergeSequence {
            inputs,
            extract_sequence,
            _marker: PhantomData,
        }
    }
}
1380
1381impl<T> GraphStage for MergeSequence<T>
1382where
1383 T: Clone + Send + 'static,
1384{
1385 type Shape = FanInShape<T, T>;
1386
1387 fn name(&self) -> &str {
1388 "MergeSequence"
1389 }
1390
1391 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1392 let inlets = (0..self.inputs)
1393 .map(|index| allocator.inlet(format!("MergeSequence.in{index}")))
1394 .collect();
1395 FanInShape::new(inlets, allocator.outlet("MergeSequence.out"))
1396 }
1397
1398 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1399 let extract = Arc::clone(&self.extract_sequence);
1400 let extract_sequence = Arc::new(move |dv: &DatumValue| -> u64 {
1401 let t: &T = dv
1402 .as_any_ref()
1403 .downcast_ref::<T>()
1404 .expect("merge-sequence extract: wrong element type");
1405 extract(t)
1406 });
1407 let typed_extract_fn = Arc::clone(&self.extract_sequence);
1410 let typed_extract: Arc<dyn Fn(&T) -> u64 + Send + Sync> = typed_extract_fn;
1411 let typed_extract: Arc<StageTypedSequenceFn> = Arc::new(typed_extract);
1412 StageSpec::merge_sequence(
1413 Arc::from(self.name()),
1414 shape.inlets(),
1415 shape.outlets(),
1416 self.inputs,
1417 extract_sequence,
1418 typed_extract,
1419 )
1420 }
1421
1422 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1423 #[derive(Clone)]
1424 struct Pending<T> {
1425 sequence: u64,
1426 elem: T,
1427 }
1428
1429 struct State<T> {
1430 next_sequence: u64,
1431 pending: Vec<Pending<T>>,
1432 completed: usize,
1433 pending_output: VecDeque<T>,
1434 }
1435
1436 fn try_emit_pending<T>(state: &mut State<T>) -> StreamResult<()>
1437 where
1438 T: Clone + Send + 'static,
1439 {
1440 while let Some(index) = state
1441 .pending
1442 .iter()
1443 .position(|item| item.sequence == state.next_sequence)
1444 {
1445 let item = state.pending.remove(index);
1446 if state
1447 .pending
1448 .iter()
1449 .any(|other| other.sequence == state.next_sequence)
1450 {
1451 return Err(StreamError::Failed(format!(
1452 "duplicate sequence {} on merge sequence",
1453 state.next_sequence
1454 )));
1455 }
1456 state.pending_output.push_back(item.elem);
1457 state.next_sequence += 1;
1458 }
1459 Ok(())
1460 }
1461
1462 struct In<T: 'static> {
1463 inlet_id: PortId,
1464 inlet_index: usize,
1465 inlet: Inlet<T>,
1466 all_inlets: Vec<Inlet<T>>,
1467 outlet: Outlet<T>,
1468 extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
1469 state: Arc<Mutex<State<T>>>,
1470 }
1471
1472 impl<T> InHandler for In<T>
1473 where
1474 T: Clone + Send + 'static,
1475 {
1476 fn on_push(
1477 &mut self,
1478 logic: &mut GraphStageLogic,
1479 _inlet: AnyInlet,
1480 ) -> StreamResult<()> {
1481 let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1482 downcast_datum(value, "grab", || {
1483 format!("inlet#{}", self.inlet_id.as_usize())
1484 })
1485 })?;
1486 {
1487 let mut state = self.state.lock().expect("merge-sequence state poisoned");
1488 let sequence = (self.extract_sequence)(&elem);
1489 if sequence < state.next_sequence {
1490 return Err(StreamError::Failed(format!(
1491 "sequence regression from {} to {} on port {}",
1492 state.next_sequence, sequence, self.inlet_index
1493 )));
1494 }
1495 state.pending.push(Pending { sequence, elem });
1496 try_emit_pending(&mut state)?;
1497 }
1498 let next = if logic.is_available(&self.outlet) {
1499 self.state
1500 .lock()
1501 .expect("merge-sequence state poisoned")
1502 .pending_output
1503 .pop_front()
1504 } else {
1505 None
1506 };
1507 if let Some(value) = next {
1508 logic.push(&self.outlet, value)?;
1509 }
1510 let state = self.state.lock().expect("merge-sequence state poisoned");
1511 if state.completed == self.all_inlets.len()
1512 && state.pending.is_empty()
1513 && state.pending_output.is_empty()
1514 {
1515 logic.complete(&self.outlet)?;
1516 } else if logic.is_available(&self.outlet)
1517 && state.pending_output.is_empty()
1518 && state.pending.len() + state.completed == self.all_inlets.len()
1519 {
1520 return Err(StreamError::Failed(format!(
1521 "expected sequence {}, but all input ports have pushed or are complete",
1522 state.next_sequence
1523 )));
1524 }
1525 if !logic.has_been_pulled(&self.inlet) {
1526 logic.pull(&self.inlet)?;
1527 }
1528 Ok(())
1529 }
1530
1531 fn on_upstream_finish(
1532 &mut self,
1533 logic: &mut GraphStageLogic,
1534 _inlet: AnyInlet,
1535 ) -> StreamResult<()> {
1536 {
1537 let mut state = self.state.lock().expect("merge-sequence state poisoned");
1538 state.completed += 1;
1539 }
1540 let state = self.state.lock().expect("merge-sequence state poisoned");
1541 if state.completed == self.all_inlets.len()
1542 && state.pending.is_empty()
1543 && state.pending_output.is_empty()
1544 {
1545 logic.complete(&self.outlet)?;
1546 } else if logic.is_available(&self.outlet)
1547 && state.pending_output.is_empty()
1548 && state.pending.len() + state.completed == self.all_inlets.len()
1549 {
1550 return Err(StreamError::Failed(format!(
1551 "expected sequence {}, but all input ports have pushed or are complete",
1552 state.next_sequence
1553 )));
1554 }
1555 Ok(())
1556 }
1557 }
1558
1559 struct Out<T: 'static> {
1560 inlets: Vec<Inlet<T>>,
1561 outlet: Outlet<T>,
1562 state: Arc<Mutex<State<T>>>,
1563 }
1564
1565 impl<T> OutHandler for Out<T>
1566 where
1567 T: Clone + Send + 'static,
1568 {
1569 fn on_pull(
1570 &mut self,
1571 logic: &mut GraphStageLogic,
1572 _outlet: AnyOutlet,
1573 ) -> StreamResult<()> {
1574 let next = self
1575 .state
1576 .lock()
1577 .expect("merge-sequence state poisoned")
1578 .pending_output
1579 .pop_front();
1580 if let Some(value) = next {
1581 logic.push(&self.outlet, value)?;
1582 } else {
1583 let state = self.state.lock().expect("merge-sequence state poisoned");
1584 if state.completed == self.inlets.len() && state.pending.is_empty() {
1585 logic.complete(&self.outlet)?;
1586 } else if state.pending.len() + state.completed == self.inlets.len() {
1587 return Err(StreamError::Failed(format!(
1588 "expected sequence {}, but all input ports have pushed or are complete",
1589 state.next_sequence
1590 )));
1591 }
1592 }
1593 for inlet in &self.inlets {
1594 if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1595 logic.pull(inlet)?;
1596 }
1597 }
1598 Ok(())
1599 }
1600 }
1601
1602 let inlets = shape.inlets_vec();
1603 let outlet = shape.outlet();
1604 let state = Arc::new(Mutex::new(State {
1605 next_sequence: 0,
1606 pending: Vec::new(),
1607 completed: 0,
1608 pending_output: VecDeque::new(),
1609 }));
1610 let mut logic = GraphStageLogic::new(shape);
1611 for (index, inlet) in inlets.iter().cloned().enumerate() {
1612 logic
1613 .set_handler(
1614 &inlet.clone(),
1615 Box::new(In {
1616 inlet_id: inlet.id(),
1617 inlet_index: index,
1618 inlet: inlet.clone(),
1619 all_inlets: inlets.clone(),
1620 outlet: outlet.clone(),
1621 extract_sequence: Arc::clone(&self.extract_sequence),
1622 state: Arc::clone(&state),
1623 }),
1624 )
1625 .unwrap();
1626 }
1627 logic
1628 .set_out_handler(
1629 &outlet.clone(),
1630 Box::new(Out {
1631 inlets,
1632 outlet: outlet.clone(),
1633 state,
1634 }),
1635 )
1636 .unwrap();
1637 logic
1638 }
1639}
1640
/// Fan-in stage that emits a `Vec<T>` snapshot holding the most recent element of every
/// input. Snapshots start once each input has pushed at least once.
///
/// `eager_complete` selects whether the stage finishes as soon as one input finishes
/// (`true`) or only after every input has finished (`false`).
#[derive(Clone, Debug)]
pub struct MergeLatest<T: 'static> {
    /// Number of input ports; always at least one.
    inputs: usize,
    eager_complete: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeLatest<T> {
    /// Creates a merge-latest stage with `inputs` input ports.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize, eager_complete: bool) -> Self {
        if inputs == 0 {
            panic!("merge-latest input count must be greater than zero");
        }
        Self {
            _marker: PhantomData,
            eager_complete,
            inputs,
        }
    }
}
1665
1666impl<T> GraphStage for MergeLatest<T>
1667where
1668 T: Clone + Send + 'static,
1669{
1670 type Shape = FanInShape<T, Vec<T>>;
1671
1672 fn name(&self) -> &str {
1673 "MergeLatest"
1674 }
1675
1676 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1677 let inlets = (0..self.inputs)
1678 .map(|index| allocator.inlet(format!("MergeLatest.in{index}")))
1679 .collect();
1680 FanInShape::new(inlets, allocator.outlet("MergeLatest.out"))
1681 }
1682
1683 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1684 let build_snapshot = Arc::new(move |values: &[&DatumValue]| -> DatumValue {
1685 let snapshot: Vec<T> = values
1686 .iter()
1687 .map(|dv| {
1688 dv.as_any_ref()
1689 .downcast_ref::<T>()
1690 .cloned()
1691 .expect("merge-latest snapshot: wrong element type")
1692 })
1693 .collect();
1694 datum(snapshot)
1695 });
1696 #[allow(clippy::type_complexity)]
1699 let typed_snapshot_fn: Arc<dyn Fn(&[Option<T>]) -> Vec<T> + Send + Sync> =
1700 Arc::new(move |slots: &[Option<T>]| {
1701 slots
1702 .iter()
1703 .map(|s| {
1704 s.clone()
1705 .expect("merge-latest typed snapshot: slot is None")
1706 })
1707 .collect()
1708 });
1709 let typed_snapshot: Arc<StageTypedSnapshotFn> = Arc::new(typed_snapshot_fn);
1710 StageSpec::merge_latest(
1711 Arc::from(self.name()),
1712 shape.inlets(),
1713 shape.outlets(),
1714 self.inputs,
1715 self.eager_complete,
1716 build_snapshot,
1717 typed_snapshot,
1718 )
1719 }
1720
1721 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1722 struct State<T> {
1723 latest: Vec<Option<T>>,
1724 seen: usize,
1725 completed: usize,
1726 pending: VecDeque<Vec<T>>,
1727 eager_complete: bool,
1728 }
1729
1730 struct In<T: 'static> {
1731 inlet_id: PortId,
1732 inlet_index: usize,
1733 inlet: Inlet<T>,
1734 all_inlets: Vec<Inlet<T>>,
1735 outlet: Outlet<Vec<T>>,
1736 state: Arc<Mutex<State<T>>>,
1737 }
1738
1739 impl<T> InHandler for In<T>
1740 where
1741 T: Clone + Send + 'static,
1742 {
1743 fn on_push(
1744 &mut self,
1745 logic: &mut GraphStageLogic,
1746 _inlet: AnyInlet,
1747 ) -> StreamResult<()> {
1748 let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
1749 downcast_datum(value, "grab", || {
1750 format!("inlet#{}", self.inlet_id.as_usize())
1751 })
1752 })?;
1753 {
1754 let mut state = self.state.lock().expect("merge-latest state poisoned");
1755 if state.latest[self.inlet_index].is_none() {
1756 state.seen += 1;
1757 }
1758 state.latest[self.inlet_index] = Some(elem);
1759 if state.seen == state.latest.len() {
1760 let snapshot = state
1761 .latest
1762 .iter()
1763 .map(|item| item.clone().expect("merge-latest seen"))
1764 .collect();
1765 state.pending.push_back(snapshot);
1766 }
1767 }
1768 let next = if logic.is_available(&self.outlet) {
1769 self.state
1770 .lock()
1771 .expect("merge-latest state poisoned")
1772 .pending
1773 .pop_front()
1774 } else {
1775 None
1776 };
1777 if let Some(value) = next {
1778 logic.push(&self.outlet, value)?;
1779 }
1780 if !logic.has_been_pulled(&self.inlet) {
1781 logic.pull(&self.inlet)?;
1782 }
1783 Ok(())
1784 }
1785
1786 fn on_upstream_finish(
1787 &mut self,
1788 logic: &mut GraphStageLogic,
1789 _inlet: AnyInlet,
1790 ) -> StreamResult<()> {
1791 let state = {
1792 let mut state = self.state.lock().expect("merge-latest state poisoned");
1793 state.completed += 1;
1794 (
1795 state.completed == self.all_inlets.len(),
1796 state.eager_complete,
1797 state.pending.is_empty(),
1798 )
1799 };
1800 if state.0 || (state.1 && state.2) {
1801 logic.complete(&self.outlet)?;
1802 }
1803 Ok(())
1804 }
1805 }
1806
1807 struct Out<T: 'static> {
1808 inlets: Vec<Inlet<T>>,
1809 outlet: Outlet<Vec<T>>,
1810 state: Arc<Mutex<State<T>>>,
1811 }
1812
1813 impl<T> OutHandler for Out<T>
1814 where
1815 T: Clone + Send + 'static,
1816 {
1817 fn on_pull(
1818 &mut self,
1819 logic: &mut GraphStageLogic,
1820 _outlet: AnyOutlet,
1821 ) -> StreamResult<()> {
1822 let next = self
1823 .state
1824 .lock()
1825 .expect("merge-latest state poisoned")
1826 .pending
1827 .pop_front();
1828 if let Some(value) = next {
1829 logic.push(&self.outlet, value)?;
1830 } else {
1831 let state = self.state.lock().expect("merge-latest state poisoned");
1832 if state.completed == self.inlets.len()
1833 || (state.eager_complete && state.completed > 0)
1834 {
1835 logic.complete(&self.outlet)?;
1836 }
1837 }
1838 for inlet in &self.inlets {
1839 if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
1840 logic.pull(inlet)?;
1841 }
1842 }
1843 Ok(())
1844 }
1845 }
1846
1847 let inlets = shape.inlets_vec();
1848 let outlet = shape.outlet();
1849 let state = Arc::new(Mutex::new(State {
1850 latest: vec![None; inlets.len()],
1851 seen: 0,
1852 completed: 0,
1853 pending: VecDeque::new(),
1854 eager_complete: self.eager_complete,
1855 }));
1856 let mut logic = GraphStageLogic::new(shape);
1857 for (index, inlet) in inlets.iter().cloned().enumerate() {
1858 logic
1859 .set_handler(
1860 &inlet.clone(),
1861 Box::new(In {
1862 inlet_id: inlet.id(),
1863 inlet_index: index,
1864 inlet: inlet.clone(),
1865 all_inlets: inlets.clone(),
1866 outlet: outlet.clone(),
1867 state: Arc::clone(&state),
1868 }),
1869 )
1870 .unwrap();
1871 }
1872 logic
1873 .set_out_handler(
1874 &outlet.clone(),
1875 Box::new(Out {
1876 inlets,
1877 outlet: outlet.clone(),
1878 state,
1879 }),
1880 )
1881 .unwrap();
1882 logic
1883 }
1884}
1885
/// Fan-out stage that routes each element to exactly one of `outputs` outlets. The outlet
/// index is chosen by a user-supplied partitioner.
///
/// A partitioner result outside `0..outputs` fails the stream. With `eager_cancel`, the
/// cancellation of any single outlet stops the whole stage. Without it, the stage keeps
/// running until every outlet has cancelled, and elements routed to a cancelled outlet are
/// dropped.
#[derive(Clone)]
pub struct Partition<T: 'static> {
    /// Maps an element to the index of the outlet that should receive it.
    partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
    /// Number of outlets; always at least one.
    outputs: usize,
    eager_cancel: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for Partition<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The partitioner closure is elided.
        let mut out = f.debug_struct("Partition");
        out.field("outputs", &self.outputs);
        out.field("eager_cancel", &self.eager_cancel);
        out.finish_non_exhaustive()
    }
}

impl<T: 'static> Partition<T> {
    /// Creates a partition stage that stops only once every outlet has cancelled.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new<F>(outputs: usize, partitioner: F) -> Self
    where
        F: Fn(&T) -> usize + Send + Sync + 'static,
    {
        Self::new_with_eager_cancel(outputs, partitioner, false)
    }

    /// Creates a partition stage with an explicit cancellation policy.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new_with_eager_cancel<F>(outputs: usize, partitioner: F, eager_cancel: bool) -> Self
    where
        F: Fn(&T) -> usize + Send + Sync + 'static,
    {
        if outputs == 0 {
            panic!("partition output count must be greater than zero");
        }
        let partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync> = Arc::new(partitioner);
        Self {
            partitioner,
            outputs,
            eager_cancel,
            _marker: PhantomData,
        }
    }
}
1932
1933impl<T> GraphStage for Partition<T>
1934where
1935 T: Clone + Send + 'static,
1936{
1937 type Shape = FanOutShape<T, T>;
1938
1939 fn name(&self) -> &str {
1940 "Partition"
1941 }
1942
1943 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
1944 let inlet = allocator.inlet("Partition.in");
1945 let outlets = (0..self.outputs)
1946 .map(|index| allocator.outlet(format!("Partition.out{index}")))
1947 .collect();
1948 FanOutShape::new(inlet, outlets)
1949 }
1950
1951 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
1952 let partitioner_clone = Arc::clone(&self.partitioner);
1953 let partitioner = Arc::new(move |dv: &DatumValue| -> usize {
1954 let t: &T = dv
1955 .as_any_ref()
1956 .downcast_ref::<T>()
1957 .expect("partition: wrong element type");
1958 partitioner_clone(t)
1959 });
1960 let typed_partitioner_fn: Arc<dyn Fn(&T) -> usize + Send + Sync> =
1961 Arc::clone(&self.partitioner);
1962 let typed_partitioner: Arc<StageTypedPartitionFn> = Arc::new(typed_partitioner_fn);
1963 StageSpec::partition(
1964 Arc::from(self.name()),
1965 shape.inlets(),
1966 shape.outlets(),
1967 self.outputs,
1968 partitioner,
1969 typed_partitioner,
1970 self.eager_cancel,
1971 )
1972 }
1973
1974 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
1975 struct State<T> {
1976 pending: Option<(usize, T)>,
1977 upstream_closed: bool,
1978 live_outlets: usize,
1979 cancelled: Vec<bool>,
1980 eager_cancel: bool,
1981 }
1982
1983 fn any_live_demand<T>(
1984 logic: &GraphStageLogic,
1985 outlets: &[Outlet<T>],
1986 cancelled: &[bool],
1987 ) -> bool
1988 where
1989 T: Clone + Send + 'static,
1990 {
1991 outlets
1992 .iter()
1993 .enumerate()
1994 .any(|(index, outlet)| !cancelled[index] && logic.is_available(outlet))
1995 }
1996
1997 struct In<T: 'static> {
1998 inlet_id: PortId,
1999 inlet: Inlet<T>,
2000 outlets: Vec<Outlet<T>>,
2001 partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
2002 state: Arc<Mutex<State<T>>>,
2003 }
2004
2005 impl<T> InHandler for In<T>
2006 where
2007 T: Clone + Send + 'static,
2008 {
2009 fn on_push(
2010 &mut self,
2011 logic: &mut GraphStageLogic,
2012 _inlet: AnyInlet,
2013 ) -> StreamResult<()> {
2014 let item: T = logic.grab_datum(self.inlet_id).and_then(|value| {
2015 downcast_datum(value, "grab", || {
2016 format!("inlet#{}", self.inlet_id.as_usize())
2017 })
2018 })?;
2019 let idx = (self.partitioner)(&item);
2020 if idx >= self.outlets.len() {
2021 return Err(StreamError::Failed(format!(
2022 "partitioner returned out-of-bounds index {idx} for {} outputs",
2023 self.outlets.len()
2024 )));
2025 }
2026 let mut pull_again = false;
2027 {
2028 let mut state = self.state.lock().expect("partition state poisoned");
2029 if state.cancelled[idx] {
2030 pull_again = !state.upstream_closed
2031 && any_live_demand(logic, &self.outlets, &state.cancelled);
2032 } else if logic.is_available(&self.outlets[idx]) {
2033 logic.push(&self.outlets[idx], item)?;
2034 pull_again = !state.upstream_closed
2035 && any_live_demand(logic, &self.outlets, &state.cancelled);
2036 } else {
2037 state.pending = Some((idx, item));
2038 }
2039 }
2040 if pull_again && !logic.has_been_pulled(&self.inlet) {
2041 logic.pull(&self.inlet)?;
2042 }
2043 Ok(())
2044 }
2045
2046 fn on_upstream_finish(
2047 &mut self,
2048 logic: &mut GraphStageLogic,
2049 _inlet: AnyInlet,
2050 ) -> StreamResult<()> {
2051 let complete_now = {
2052 let mut state = self.state.lock().expect("partition state poisoned");
2053 state.upstream_closed = true;
2054 state.pending.is_none()
2055 };
2056 if complete_now {
2057 for outlet in &self.outlets {
2058 if !logic.is_closed(outlet) {
2059 logic.complete(outlet)?;
2060 }
2061 }
2062 }
2063 Ok(())
2064 }
2065 }
2066
2067 struct Out<T: 'static> {
2068 index: usize,
2069 inlet: Inlet<T>,
2070 outlets: Vec<Outlet<T>>,
2071 state: Arc<Mutex<State<T>>>,
2072 }
2073
2074 impl<T> OutHandler for Out<T>
2075 where
2076 T: Clone + Send + 'static,
2077 {
2078 fn on_pull(
2079 &mut self,
2080 logic: &mut GraphStageLogic,
2081 _outlet: AnyOutlet,
2082 ) -> StreamResult<()> {
2083 let mut complete_now = false;
2084 let pending = {
2085 let mut state = self.state.lock().expect("partition state poisoned");
2086 if let Some((idx, _)) = &state.pending
2087 && *idx == self.index
2088 {
2089 state.pending.take()
2090 } else {
2091 None
2092 }
2093 };
2094 if let Some((_, item)) = pending {
2095 logic.push(&self.outlets[self.index], item)?;
2096 let state = self.state.lock().expect("partition state poisoned");
2097 if state.upstream_closed {
2098 complete_now = true;
2099 } else if any_live_demand(logic, &self.outlets, &state.cancelled)
2100 && !logic.has_been_pulled(&self.inlet)
2101 {
2102 logic.pull(&self.inlet)?;
2103 }
2104 } else {
2105 let state = self.state.lock().expect("partition state poisoned");
2106 if state.upstream_closed {
2107 complete_now = true;
2108 } else if any_live_demand(logic, &self.outlets, &state.cancelled)
2109 && !logic.has_been_pulled(&self.inlet)
2110 {
2111 logic.pull(&self.inlet)?;
2112 }
2113 }
2114 if complete_now {
2115 for outlet in &self.outlets {
2116 if !logic.is_closed(outlet) {
2117 logic.complete(outlet)?;
2118 }
2119 }
2120 }
2121 Ok(())
2122 }
2123
2124 fn on_downstream_finish(
2125 &mut self,
2126 logic: &mut GraphStageLogic,
2127 _outlet: AnyOutlet,
2128 ) -> StreamResult<()> {
2129 let (cancel_stage, clear_pending) = {
2130 let mut state = self.state.lock().expect("partition state poisoned");
2131 if state.cancelled[self.index] {
2132 return Ok(());
2133 }
2134 state.cancelled[self.index] = true;
2135 state.live_outlets -= 1;
2136 let clear_pending = state
2137 .pending
2138 .as_ref()
2139 .is_some_and(|(idx, _)| *idx == self.index);
2140 let cancel_stage = state.eager_cancel || state.live_outlets == 0;
2141 if clear_pending {
2142 state.pending = None;
2143 }
2144 (cancel_stage, clear_pending)
2145 };
2146 if cancel_stage {
2147 logic.complete_stage()?;
2148 } else if clear_pending
2149 && !logic.has_been_pulled(&self.inlet)
2150 && !logic.is_closed(&self.inlet)
2151 {
2152 let state = self.state.lock().expect("partition state poisoned");
2153 if any_live_demand(logic, &self.outlets, &state.cancelled) {
2154 logic.pull(&self.inlet)?;
2155 }
2156 }
2157 Ok(())
2158 }
2159 }
2160
2161 let inlet = shape.inlet();
2162 let outlets = shape.outlets_vec();
2163 let state = Arc::new(Mutex::new(State {
2164 pending: None,
2165 upstream_closed: false,
2166 live_outlets: outlets.len(),
2167 cancelled: vec![false; outlets.len()],
2168 eager_cancel: self.eager_cancel,
2169 }));
2170 let mut logic = GraphStageLogic::new(shape);
2171 logic
2172 .set_handler(
2173 &inlet,
2174 Box::new(In {
2175 inlet_id: inlet.id(),
2176 inlet: inlet.clone(),
2177 outlets: outlets.clone(),
2178 partitioner: Arc::clone(&self.partitioner),
2179 state: Arc::clone(&state),
2180 }),
2181 )
2182 .unwrap();
2183 for (index, outlet) in outlets.iter().cloned().enumerate() {
2184 logic
2185 .set_out_handler(
2186 &outlet,
2187 Box::new(Out {
2188 index,
2189 inlet: inlet.clone(),
2190 outlets: outlets.clone(),
2191 state: Arc::clone(&state),
2192 }),
2193 )
2194 .unwrap();
2195 }
2196 logic
2197 }
2198}
2199
/// Fan-out stage that splits each incoming `(A, B)` pair. The `A` half goes to the first
/// outlet and the `B` half to the second.
#[derive(Clone, Debug)]
pub struct Unzip<A: 'static, B: 'static> {
    _marker: PhantomData<fn() -> (A, B)>,
}

impl<A: 'static, B: 'static> Unzip<A, B> {
    /// Creates a new unzip stage.
    #[must_use]
    pub fn new() -> Self {
        Unzip {
            _marker: PhantomData,
        }
    }
}

impl<A: 'static, B: 'static> Default for Unzip<A, B> {
    /// Same as [`Unzip::new`].
    fn default() -> Self {
        Unzip::new()
    }
}
2221
2222impl<A, B> GraphStage for Unzip<A, B>
2223where
2224 A: Clone + Send + 'static,
2225 B: Clone + Send + 'static,
2226{
2227 type Shape = FanOutShape2<(A, B), A, B>;
2228
2229 fn name(&self) -> &str {
2230 "Unzip"
2231 }
2232
2233 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
2234 FanOutShape2::new(
2235 allocator.inlet("Unzip.in"),
2236 allocator.outlet("Unzip.out0"),
2237 allocator.outlet("Unzip.out1"),
2238 )
2239 }
2240
2241 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2242 let split = Arc::new(|dv: DatumValue| -> (DatumValue, DatumValue) {
2243 let pair: (A, B) =
2244 downcast_datum(dv, "unzip", || "Unzip.in").expect("unzip: wrong element type");
2245 (datum(pair.0), datum(pair.1))
2246 });
2247 #[allow(clippy::type_complexity)]
2252 let typed_split_fn: Arc<dyn Fn((A, B)) -> (A, B) + Send + Sync> =
2253 Arc::new(|pair: (A, B)| pair);
2254 let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
2255 StageSpec::unzip(
2256 Arc::from(self.name()),
2257 shape.inlets(),
2258 shape.outlets(),
2259 split,
2260 typed_split,
2261 )
2262 }
2263
2264 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
2265 UnzipWith::new(|pair: (A, B)| pair).create_logic(shape)
2266 }
2267}
2268
/// Fan-out stage that splits each input element into two values with a user-supplied
/// function. The first value goes to `out0` and the second to `out1`.
#[derive(Clone)]
pub struct UnzipWith<In: 'static, Out0: 'static, Out1: 'static> {
    /// The splitting function.
    split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
    _marker: PhantomData<fn(In) -> (Out0, Out1)>,
}

impl<In: 'static, Out0: 'static, Out1: 'static> fmt::Debug for UnzipWith<In, Out0, Out1> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the type name is shown; the closure is elided.
        let mut builder = f.debug_struct("UnzipWith");
        builder.finish_non_exhaustive()
    }
}

impl<In: 'static, Out0: 'static, Out1: 'static> UnzipWith<In, Out0, Out1> {
    /// Creates an unzip stage driven by `split`.
    #[must_use]
    pub fn new<F>(split: F) -> Self
    where
        F: Fn(In) -> (Out0, Out1) + Send + Sync + 'static,
    {
        let split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::new(split);
        Self {
            split,
            _marker: PhantomData,
        }
    }
}
2295
2296impl<In, Out0, Out1> GraphStage for UnzipWith<In, Out0, Out1>
2297where
2298 In: Clone + Send + 'static,
2299 Out0: Clone + Send + 'static,
2300 Out1: Clone + Send + 'static,
2301{
2302 type Shape = FanOutShape2<In, Out0, Out1>;
2303
2304 fn name(&self) -> &str {
2305 "UnzipWith"
2306 }
2307
2308 fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
2309 FanOutShape2::new(
2310 allocator.inlet("UnzipWith.in"),
2311 allocator.outlet("UnzipWith.out0"),
2312 allocator.outlet("UnzipWith.out1"),
2313 )
2314 }
2315
2316 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2317 let split_fn = Arc::clone(&self.split);
2318 let split = Arc::new(move |dv: DatumValue| -> (DatumValue, DatumValue) {
2319 let value: In = downcast_datum(dv, "unzip_with", || "UnzipWith.in")
2320 .expect("unzip-with: wrong element type");
2321 let (out0, out1) = split_fn(value);
2322 (datum(out0), datum(out1))
2323 });
2324 let typed_split_fn: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::clone(&self.split);
2327 let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
2328 StageSpec::unzip(
2329 Arc::from(self.name()),
2330 shape.inlets(),
2331 shape.outlets(),
2332 split,
2333 typed_split,
2334 )
2335 }
2336
2337 fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
2338 struct State {
2339 left_open: bool,
2340 right_open: bool,
2341 upstream_closed: bool,
2342 }
2343
2344 struct InHandlerState<In: 'static, Out0: 'static, Out1: 'static> {
2345 inlet_id: PortId,
2346 inlet: Inlet<In>,
2347 out0: Outlet<Out0>,
2348 out1: Outlet<Out1>,
2349 split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
2350 state: Arc<Mutex<State>>,
2351 }
2352
2353 impl<In, Out0, Out1> InHandler for InHandlerState<In, Out0, Out1>
2354 where
2355 In: Clone + Send + 'static,
2356 Out0: Clone + Send + 'static,
2357 Out1: Clone + Send + 'static,
2358 {
2359 fn on_push(
2360 &mut self,
2361 logic: &mut GraphStageLogic,
2362 _inlet: AnyInlet,
2363 ) -> StreamResult<()> {
2364 let value: In = logic.grab_datum(self.inlet_id).and_then(|value| {
2365 downcast_datum(value, "grab", || {
2366 format!("inlet#{}", self.inlet_id.as_usize())
2367 })
2368 })?;
2369 let (left, right) = (self.split)(value);
2370 let state = self.state.lock().expect("unzip-with state poisoned");
2371 if state.left_open {
2372 logic.push(&self.out0, left)?;
2373 }
2374 if state.right_open {
2375 logic.push(&self.out1, right)?;
2376 }
2377 drop(state);
2378 let state = self.state.lock().expect("unzip-with state poisoned");
2379 let left_ready = !state.left_open || logic.is_available(&self.out0);
2380 let right_ready = !state.right_open || logic.is_available(&self.out1);
2381 if (state.left_open || state.right_open)
2382 && left_ready
2383 && right_ready
2384 && !logic.has_been_pulled(&self.inlet)
2385 {
2386 logic.pull(&self.inlet)?;
2387 }
2388 Ok(())
2389 }
2390
2391 fn on_upstream_finish(
2392 &mut self,
2393 logic: &mut GraphStageLogic,
2394 _inlet: AnyInlet,
2395 ) -> StreamResult<()> {
2396 self.state
2397 .lock()
2398 .expect("unzip-with state poisoned")
2399 .upstream_closed = true;
2400 if !logic.is_closed(&self.out0) {
2401 logic.complete(&self.out0)?;
2402 }
2403 if !logic.is_closed(&self.out1) {
2404 logic.complete(&self.out1)?;
2405 }
2406 Ok(())
2407 }
2408 }
2409
2410 struct Out<In: 'static, Out0: 'static, Out1: 'static> {
2411 is_left: bool,
2412 inlet: Inlet<In>,
2413 out0: Outlet<Out0>,
2414 out1: Outlet<Out1>,
2415 state: Arc<Mutex<State>>,
2416 }
2417
2418 impl<In, Out0, Out1> OutHandler for Out<In, Out0, Out1>
2419 where
2420 In: Clone + Send + 'static,
2421 Out0: Clone + Send + 'static,
2422 Out1: Clone + Send + 'static,
2423 {
2424 fn on_pull(
2425 &mut self,
2426 logic: &mut GraphStageLogic,
2427 _outlet: AnyOutlet,
2428 ) -> StreamResult<()> {
2429 let state = self.state.lock().expect("unzip-with state poisoned");
2430 let left_ready = !state.left_open || logic.is_available(&self.out0);
2431 let right_ready = !state.right_open || logic.is_available(&self.out1);
2432 if state.upstream_closed {
2433 drop(state);
2434 if !logic.is_closed(&self.out0) {
2435 logic.complete(&self.out0)?;
2436 }
2437 if !logic.is_closed(&self.out1) {
2438 logic.complete(&self.out1)?;
2439 }
2440 } else if (state.left_open || state.right_open)
2441 && left_ready
2442 && right_ready
2443 && !logic.has_been_pulled(&self.inlet)
2444 {
2445 drop(state);
2446 logic.pull(&self.inlet)?;
2447 }
2448 Ok(())
2449 }
2450
2451 fn on_downstream_finish(
2452 &mut self,
2453 logic: &mut GraphStageLogic,
2454 _outlet: AnyOutlet,
2455 ) -> StreamResult<()> {
2456 let mut state = self.state.lock().expect("unzip-with state poisoned");
2457 if self.is_left {
2458 state.left_open = false;
2459 } else {
2460 state.right_open = false;
2461 }
2462 if !state.left_open && !state.right_open {
2463 logic.complete_stage()?;
2464 return Ok(());
2465 }
2466 let left_ready = !state.left_open || logic.is_available(&self.out0);
2467 let right_ready = !state.right_open || logic.is_available(&self.out1);
2468 if !state.upstream_closed
2469 && (state.left_open || state.right_open)
2470 && left_ready
2471 && right_ready
2472 && !logic.has_been_pulled(&self.inlet)
2473 {
2474 logic.pull(&self.inlet)?;
2475 }
2476 Ok(())
2477 }
2478 }
2479
2480 let inlet = shape.inlet();
2481 let out0 = shape.out0();
2482 let out1 = shape.out1();
2483 let state = Arc::new(Mutex::new(State {
2484 left_open: true,
2485 right_open: true,
2486 upstream_closed: false,
2487 }));
2488 let mut logic = GraphStageLogic::new(shape);
2489 logic
2490 .set_handler(
2491 &inlet,
2492 Box::new(InHandlerState {
2493 inlet_id: inlet.id(),
2494 inlet: inlet.clone(),
2495 out0: out0.clone(),
2496 out1: out1.clone(),
2497 split: Arc::clone(&self.split),
2498 state: Arc::clone(&state),
2499 }),
2500 )
2501 .unwrap();
2502 logic
2503 .set_out_handler(
2504 &out0,
2505 Box::new(Out {
2506 is_left: true,
2507 inlet: inlet.clone(),
2508 out0: out0.clone(),
2509 out1: out1.clone(),
2510 state: Arc::clone(&state),
2511 }),
2512 )
2513 .unwrap();
2514 logic
2515 .set_out_handler(
2516 &out1.clone(),
2517 Box::new(Out {
2518 is_left: false,
2519 inlet: inlet.clone(),
2520 out0: out0.clone(),
2521 out1: out1.clone(),
2522 state,
2523 }),
2524 )
2525 .unwrap();
2526 logic
2527 }
2528}
2529
/// `T -> T` flow stage whose spec is built with `StageSpec::async_boundary`.
/// Presumably it marks a boundary between separately scheduled parts of a graph; confirm
/// against the runtime.
#[derive(Clone, Debug)]
pub struct AsyncBoundary<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> AsyncBoundary<T> {
    /// Creates a new async-boundary stage.
    #[must_use]
    pub fn new() -> Self {
        AsyncBoundary {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for AsyncBoundary<T> {
    /// Same as [`AsyncBoundary::new`].
    fn default() -> Self {
        AsyncBoundary::new()
    }
}
2551
2552impl<T> GraphStage for AsyncBoundary<T>
2553where
2554 T: Clone + Send + 'static,
2555{
2556 type Shape = FlowShape<T, T>;
2557
2558 fn name(&self) -> &str {
2559 "AsyncBoundary"
2560 }
2561
2562 fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
2563 let first_id = next_port_id_block(2);
2564 FlowShape::new(
2565 Inlet::with_arc_name(first_id, async_boundary_inlet_name()),
2566 Outlet::with_arc_name(first_id.offset(1), async_boundary_outlet_name()),
2567 )
2568 }
2569
2570 fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
2571 self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
2572 }
2573
2574 fn stage_spec_with_ports(
2575 &self,
2576 _shape: &Self::Shape,
2577 inlets: Vec<AnyInlet>,
2578 outlets: Vec<AnyOutlet>,
2579 ) -> StageSpec {
2580 StageSpec::async_boundary(async_boundary_stage_name(), inlets, outlets)
2581 }
2582}