use super::*;
/// Identity flow stage with shape `FlowShape<T, T>`; its runtime behaviour is
/// described by `StageSpec::identity`.
#[derive(Clone, Debug)]
pub struct Identity<T: 'static> {
    // `fn() -> T` ties the stage to `T` without storing or owning a `T`.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Identity<T> {
    /// Creates a new identity stage.
    #[must_use]
    pub fn new() -> Self {
        Identity { _marker: PhantomData }
    }
}

impl<T: 'static> Default for Identity<T> {
    /// Same as [`Identity::new`].
    fn default() -> Self {
        Identity::new()
    }
}
impl<T> GraphStage for Identity<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FlowShape<T, T>;

    fn name(&self) -> &str {
        "Identity"
    }

    /// Takes two ids from `next_port_id_block(2)`: the first one names the
    /// inlet and `offset(1)` names the outlet. The `PortAllocator` is unused.
    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
        let base_id = next_port_id_block(2);
        let inlet = Inlet::with_arc_name(base_id, identity_inlet_name());
        let outlet = Outlet::with_arc_name(base_id.offset(1), identity_outlet_name());
        FlowShape::new(inlet, outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        self.stage_spec_with_ports(shape, shape.inlets(), shape.outlets())
    }

    /// Describes the stage via `StageSpec::identity`; the shape is not consulted.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        StageSpec::identity(inlets, outlets)
    }
}
/// Flow stage that turns each `In` element into an `Out` element with a
/// user-supplied function.
#[derive(Clone)]
pub struct MapStage<In: 'static, Out: 'static> {
    // The mapping function, shared through `Arc` so clones reuse it.
    f: Arc<dyn Fn(In) -> Out + Send + Sync>,
    _marker: PhantomData<fn(In) -> Out>,
}

impl<In: 'static, Out: 'static> fmt::Debug for MapStage<In, Out> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // A `dyn Fn` is not `Debug`, so only a fixed label is printed.
        let mut builder = formatter.debug_struct("MapStage");
        builder.field("name", &"Map");
        builder.finish_non_exhaustive()
    }
}

impl<In: 'static, Out: 'static> MapStage<In, Out> {
    /// Wraps `f` as a map stage.
    #[must_use]
    pub fn new<F>(f: F) -> Self
    where
        F: Fn(In) -> Out + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(In) -> Out + Send + Sync> = Arc::new(f);
        MapStage {
            f: shared,
            _marker: PhantomData,
        }
    }
}
impl<In, Out> GraphStage for MapStage<In, Out>
where
    In: Clone + Send + 'static,
    Out: Clone + Send + 'static,
{
    type Shape = FlowShape<In, Out>;

    fn name(&self) -> &str {
        "Map"
    }

    /// Takes two ids from `next_port_id_block(2)`: inlet first, outlet at
    /// `offset(1)`. The `PortAllocator` is unused.
    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
        let base_id = next_port_id_block(2);
        let inlet = Inlet::with_arc_name(base_id, map_inlet_name());
        let outlet = Outlet::with_arc_name(base_id.offset(1), map_outlet_name());
        FlowShape::new(inlet, outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Builds the map spec from two views of the same function:
    /// * a type-erased mapper that downcasts the incoming `DatumValue` to
    ///   `In`, applies the function and wraps the `Out` result with `datum`;
    /// * the typed function itself, handed over as a `StageTypedMapFn`.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let func = Arc::clone(&self.f);
        let typed: Arc<StageTypedMapFn> = Arc::new(Arc::clone(&self.f));
        let mapper = Arc::new(move |input: DatumValue| {
            let input: In = downcast_datum(input, "map", || "Map.in")?;
            Ok(datum(func(input)))
        });
        StageSpec::map(map_stage_name(), inlets, outlets, mapper, typed)
    }
}
/// Fan-out stage with one inlet and `outputs` outlets; its runtime behaviour
/// is described by `StageSpec::broadcast`.
#[derive(Clone, Debug)]
pub struct Broadcast<T: 'static> {
    // Number of outlets; `new` guarantees it is non-zero.
    outputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Broadcast<T> {
    /// Creates a broadcast stage with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        assert!(
            outputs != 0,
            "broadcast output count must be greater than zero"
        );
        Broadcast {
            outputs,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for Broadcast<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanOutShape<T, T>;

    fn name(&self) -> &str {
        "Broadcast"
    }

    /// Allocates the inlet first, then outlets `Broadcast.out0` ..
    /// `Broadcast.out{outputs - 1}` in index order.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlet = allocator.inlet_arc(broadcast_inlet_name());
        let outlets = (0..self.outputs)
            .map(|index| format!("Broadcast.out{index}"))
            .map(|port_name| allocator.outlet(port_name))
            .collect();
        FanOutShape::new(inlet, outlets)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::broadcast`; the shape is not consulted.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let stage_name = broadcast_stage_name();
        StageSpec::broadcast(stage_name, inlets, outlets)
    }
}
/// Fan-out stage with one inlet and `outputs` outlets; its runtime behaviour
/// is described by `StageSpec::balance`.
#[derive(Clone, Debug)]
pub struct Balance<T: 'static> {
    // Number of outlets; `new` guarantees it is non-zero.
    outputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Balance<T> {
    /// Creates a balance stage with `outputs` outlets.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new(outputs: usize) -> Self {
        assert!(
            outputs != 0,
            "balance output count must be greater than zero"
        );
        Balance {
            outputs,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for Balance<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanOutShape<T, T>;

    fn name(&self) -> &str {
        "Balance"
    }

    /// Allocates the inlet first, then outlets `Balance.out0` ..
    /// `Balance.out{outputs - 1}` in index order.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlet = allocator.inlet_arc(balance_inlet_name());
        let outlets = (0..self.outputs)
            .map(|index| format!("Balance.out{index}"))
            .map(|port_name| allocator.outlet(port_name))
            .collect();
        FanOutShape::new(inlet, outlets)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::balance`; the shape is not consulted.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let stage_name = balance_stage_name();
        StageSpec::balance(stage_name, inlets, outlets)
    }
}
/// Fan-in stage with `inputs` inlets and one outlet; its runtime behaviour is
/// described by `StageSpec::merge`.
#[derive(Clone, Debug)]
pub struct Merge<T: 'static> {
    // Number of inlets; `new` guarantees it is non-zero.
    inputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Merge<T> {
    /// Creates a merge stage with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs != 0, "merge input count must be greater than zero");
        Merge {
            inputs,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for Merge<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanInShape<T, T>;

    fn name(&self) -> &str {
        "Merge"
    }

    /// Allocates inlets `Merge.in0` .. `Merge.in{inputs - 1}` in index order,
    /// then the outlet.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlets = (0..self.inputs)
            .map(|index| format!("Merge.in{index}"))
            .map(|port_name| allocator.inlet(port_name))
            .collect();
        let outlet = allocator.outlet_arc(merge_outlet_name());
        FanInShape::new(inlets, outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::merge`; the shape is not consulted.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let stage_name = merge_stage_name();
        StageSpec::merge(stage_name, inlets, outlets)
    }
}
/// Fan-in stage with `inputs` inlets and one outlet; its runtime behaviour is
/// described by `StageSpec::concat`.
#[derive(Clone, Debug)]
pub struct Concat<T: 'static> {
    // Number of inlets; `new` guarantees it is at least two.
    inputs: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Concat<T> {
    /// Creates a concat stage with `inputs` inlets.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new(inputs: usize) -> Self {
        assert!(inputs >= 2, "concat input count must be greater than one");
        Concat {
            inputs,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for Concat<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanInShape<T, T>;

    fn name(&self) -> &str {
        "Concat"
    }

    /// Allocates inlets `Concat.in0` .. `Concat.in{inputs - 1}` in index
    /// order, then the outlet.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlets = (0..self.inputs)
            .map(|index| format!("Concat.in{index}"))
            .map(|port_name| allocator.inlet(port_name))
            .collect();
        let outlet = allocator.outlet_arc(concat_outlet_name());
        FanInShape::new(inlets, outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::concat`; the shape is not consulted.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let stage_name = concat_stage_name();
        StageSpec::concat(stage_name, inlets, outlets)
    }
}
/// Two-input fan-in stage (primary and secondary inlet) whose runtime
/// behaviour is described by `StageSpec::or_else`.
#[derive(Clone, Debug)]
pub struct OrElse<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> OrElse<T> {
    /// Creates a new or-else stage.
    #[must_use]
    pub fn new() -> Self {
        OrElse { _marker: PhantomData }
    }
}

impl<T: 'static> Default for OrElse<T> {
    /// Same as [`OrElse::new`].
    fn default() -> Self {
        OrElse::new()
    }
}
impl<T> GraphStage for OrElse<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanInShape<T, T>;

    fn name(&self) -> &str {
        "OrElse"
    }

    /// Allocates the primary inlet, the secondary inlet and the outlet, in
    /// that order; the primary inlet is inlet 0.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let primary = allocator.inlet_arc(or_else_primary_name());
        let secondary = allocator.inlet_arc(or_else_secondary_name());
        let outlet = allocator.outlet_arc(or_else_outlet_name());
        FanInShape::new(vec![primary, secondary], outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::or_else`; the shape is not consulted.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let stage_name = or_else_stage_name();
        StageSpec::or_else(stage_name, inlets, outlets)
    }
}
/// Fan-in stage with `inputs` inlets whose runtime behaviour is described by
/// `StageSpec::interleave` (parameterised by `segment_size` and
/// `eager_close`).
#[derive(Clone, Debug)]
pub struct Interleave<T: 'static> {
    // Number of inlets; at least two.
    inputs: usize,
    // Segment size forwarded to the spec; non-zero.
    segment_size: usize,
    // Flag forwarded to the spec as-is.
    eager_close: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> Interleave<T> {
    /// Creates an interleave stage with `eager_close` set to `false`.
    ///
    /// # Panics
    ///
    /// See [`Interleave::new_with_eager_close`].
    #[must_use]
    pub fn new(inputs: usize, segment_size: usize) -> Self {
        Interleave::new_with_eager_close(inputs, segment_size, false)
    }

    /// Creates an interleave stage.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two or `segment_size` is zero
    /// (checked in that order).
    #[must_use]
    pub fn new_with_eager_close(inputs: usize, segment_size: usize, eager_close: bool) -> Self {
        assert!(
            inputs >= 2,
            "interleave input count must be greater than one"
        );
        assert!(
            segment_size != 0,
            "interleave segment size must be greater than zero"
        );
        Interleave {
            inputs,
            segment_size,
            eager_close,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for Interleave<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanInShape<T, T>;

    fn name(&self) -> &str {
        "Interleave"
    }

    /// Allocates inlets `Interleave.in0` .. `Interleave.in{inputs - 1}` in
    /// index order, then the outlet.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlets = (0..self.inputs)
            .map(|index| format!("Interleave.in{index}"))
            .map(|port_name| allocator.inlet(port_name))
            .collect();
        let outlet = allocator.outlet_arc(interleave_outlet_name());
        FanInShape::new(inlets, outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::interleave`, forwarding the
    /// configured segment size and eager-close flag.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let stage_name = interleave_stage_name();
        let Self {
            segment_size,
            eager_close,
            ..
        } = *self;
        StageSpec::interleave(stage_name, inlets, outlets, segment_size, eager_close)
    }
}
/// Fan-in stage with one preferred inlet plus `secondary_ports` secondary
/// inlets; its runtime behaviour is described by `StageSpec::merge_preferred`.
#[derive(Clone, Debug)]
pub struct MergePreferred<T: 'static> {
    // Number of secondary inlets; `new` guarantees it is non-zero.
    secondary_ports: usize,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePreferred<T> {
    /// Creates a merge-preferred stage with `secondary_ports` secondary inlets.
    ///
    /// # Panics
    ///
    /// Panics if `secondary_ports` is zero.
    #[must_use]
    pub fn new(secondary_ports: usize) -> Self {
        assert!(
            secondary_ports != 0,
            "merge-preferred secondary input count must be greater than zero"
        );
        MergePreferred {
            secondary_ports,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for MergePreferred<T>
where
    T: Clone + Send + 'static,
{
    type Shape = MergePreferredShape<T>;

    fn name(&self) -> &str {
        "MergePreferred"
    }

    /// Allocates the preferred inlet, then secondary inlets
    /// `MergePreferred.in0` .. `MergePreferred.in{secondary_ports - 1}`,
    /// then the outlet.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let preferred = allocator.inlet_arc(merge_preferred_preferred_name());
        let secondary = (0..self.secondary_ports)
            .map(|index| format!("MergePreferred.in{index}"))
            .map(|port_name| allocator.inlet(port_name))
            .collect();
        let outlet = allocator.outlet_arc(merge_preferred_outlet_name());
        MergePreferredShape::new(preferred, secondary, outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::merge_preferred`; the shape is not
    /// consulted.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let stage_name = merge_preferred_stage_name();
        StageSpec::merge_preferred(stage_name, inlets, outlets)
    }
}
/// Fan-in stage with one inlet per entry of `weights`; its runtime behaviour
/// is described by `StageSpec::merge_prioritized`.
#[derive(Clone, Debug)]
pub struct MergePrioritized<T: 'static> {
    // One non-zero weight per inlet; never empty.
    weights: Vec<usize>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergePrioritized<T> {
    /// Creates a prioritized merge with one inlet per weight.
    ///
    /// # Panics
    ///
    /// Panics if `weights` is empty or contains a zero (checked in that order).
    #[must_use]
    pub fn new(weights: Vec<usize>) -> Self {
        assert!(!weights.is_empty(), "prioritized merge must have inputs");
        assert!(
            !weights.contains(&0),
            "prioritized merge weights must be greater than zero"
        );
        MergePrioritized {
            weights,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for MergePrioritized<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanInShape<T, T>;

    fn name(&self) -> &str {
        "MergePrioritized"
    }

    /// Allocates one inlet per weight (`MergePrioritized.in0`, ...) in index
    /// order, then the outlet.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlet_count = self.weights.len();
        let inlets = (0..inlet_count)
            .map(|index| format!("MergePrioritized.in{index}"))
            .map(|port_name| allocator.inlet(port_name))
            .collect();
        let outlet = allocator.outlet_arc(merge_prioritized_outlet_name());
        FanInShape::new(inlets, outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::merge_prioritized`, passing a copy
    /// of the configured weights.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let stage_name = merge_prioritized_stage_name();
        let weights = self.weights.clone();
        StageSpec::merge_prioritized(stage_name, inlets, outlets, weights)
    }
}
/// Two-input stage producing `(Left, Right)` pairs; its runtime behaviour is
/// described by `StageSpec::zip`.
#[derive(Clone, Debug)]
pub struct Zip<Left: 'static, Right: 'static> {
    _marker: PhantomData<fn() -> (Left, Right)>,
}

impl<Left: 'static, Right: 'static> Zip<Left, Right> {
    /// Creates a new zip stage.
    #[must_use]
    pub fn new() -> Self {
        Zip { _marker: PhantomData }
    }
}

impl<Left: 'static, Right: 'static> Default for Zip<Left, Right> {
    /// Same as [`Zip::new`].
    fn default() -> Self {
        Zip::new()
    }
}
impl<Left, Right> GraphStage for Zip<Left, Right>
where
    Left: Clone + Send + 'static,
    Right: Clone + Send + 'static,
{
    type Shape = ZipShape<Left, Right>;

    fn name(&self) -> &str {
        "Zip"
    }

    /// Takes three ids from `next_port_id_block(3)`: left inlet, right inlet
    /// (`offset(1)`) and outlet (`offset(2)`). The `PortAllocator` is unused.
    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
        let base_id = next_port_id_block(3);
        let left = Inlet::with_arc_name(base_id, zip_in0_name());
        let right = Inlet::with_arc_name(base_id.offset(1), zip_in1_name());
        let outlet = Outlet::with_arc_name(base_id.offset(2), zip_outlet_name());
        ZipShape::new(left, right, outlet)
    }

    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Describes the stage via `StageSpec::zip` with a type-erased zipper that
    /// downcasts each side to `Left` / `Right` and wraps the pair with `datum`.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        let zipper = Arc::new(move |lhs: DatumValue, rhs: DatumValue| {
            let lhs: Left = downcast_datum(lhs, "zip", || "Zip.in0")?;
            let rhs: Right = downcast_datum(rhs, "zip", || "Zip.in1")?;
            Ok(datum((lhs, rhs)))
        });
        StageSpec::zip(zip_stage_name(), inlets, outlets, zipper)
    }
}
/// Two-input fan-in stage that merges two inputs by `Ord`.
#[derive(Clone, Debug)]
pub struct MergeSorted<T: 'static> {
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeSorted<T> {
    /// Creates a new sorted-merge stage.
    #[must_use]
    pub fn new() -> Self {
        MergeSorted { _marker: PhantomData }
    }
}

impl<T: 'static> Default for MergeSorted<T> {
    /// Same as [`MergeSorted::new`].
    fn default() -> Self {
        MergeSorted::new()
    }
}
impl<T> GraphStage for MergeSorted<T>
where
T: Clone + Ord + Send + 'static,
{
type Shape = FanInShape<T, T>;
fn name(&self) -> &str {
"MergeSorted"
}
fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
let inlets = vec![
allocator.inlet("MergeSorted.in0"),
allocator.inlet("MergeSorted.in1"),
];
FanInShape::new(inlets, allocator.outlet("MergeSorted.out"))
}
fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
let compare = Arc::new(
move |a: &DatumValue, b: &DatumValue| -> std::cmp::Ordering {
let a_t: &T = a
.as_any_ref()
.downcast_ref::<T>()
.expect("merge-sorted compare: wrong element type");
let b_t: &T = b
.as_any_ref()
.downcast_ref::<T>()
.expect("merge-sorted compare: wrong element type");
a_t.cmp(b_t)
},
);
StageSpec::merge_sorted(
Arc::from(self.name()),
shape.inlets(),
shape.outlets(),
compare,
)
}
fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
struct State<T> {
left: VecDeque<T>,
right: VecDeque<T>,
left_closed: bool,
right_closed: bool,
pending: VecDeque<T>,
}
impl<T> Default for State<T> {
fn default() -> Self {
Self {
left: VecDeque::new(),
right: VecDeque::new(),
left_closed: false,
right_closed: false,
pending: VecDeque::new(),
}
}
}
fn maybe_queue_output<T>(state: &mut State<T>) -> bool
where
T: Clone + Ord,
{
let next = match (state.left.front(), state.right.front()) {
(Some(left), Some(right)) => {
if left <= right {
state.left.pop_front()
} else {
state.right.pop_front()
}
}
(Some(_), None) if state.right_closed => state.left.pop_front(),
(None, Some(_)) if state.left_closed => state.right.pop_front(),
_ => None,
};
if let Some(value) = next {
state.pending.push_back(value);
true
} else {
false
}
}
fn maybe_complete<T>(
logic: &mut GraphStageLogic,
outlet: &Outlet<T>,
state: &State<T>,
) -> StreamResult<()>
where
T: Clone + Send + 'static,
{
if state.left_closed
&& state.right_closed
&& state.left.is_empty()
&& state.right.is_empty()
&& state.pending.is_empty()
&& !logic.is_closed(outlet)
{
logic.complete(outlet)?;
}
Ok(())
}
fn maybe_pull<T>(
logic: &mut GraphStageLogic,
left: &Inlet<T>,
right: &Inlet<T>,
state: &State<T>,
) -> StreamResult<()>
where
T: Clone + Send + 'static,
{
if state.left.is_empty() && !state.left_closed && !logic.has_been_pulled(left) {
logic.pull(left)?;
}
if state.right.is_empty() && !state.right_closed && !logic.has_been_pulled(right) {
logic.pull(right)?;
}
Ok(())
}
fn maybe_drain<T>(
logic: &mut GraphStageLogic,
outlet: &Outlet<T>,
state: &Arc<Mutex<State<T>>>,
) -> StreamResult<()>
where
T: Clone + Ord + Send + 'static,
{
let next = if logic.is_available(outlet) {
state
.lock()
.expect("merge-sorted state poisoned")
.pending
.pop_front()
} else {
None
};
if let Some(value) = next {
logic.push(outlet, value)?;
}
Ok(())
}
struct In<T: 'static> {
inlet_id: PortId,
left: Inlet<T>,
right: Inlet<T>,
outlet: Outlet<T>,
state: Arc<Mutex<State<T>>>,
}
impl<T> InHandler for In<T>
where
T: Clone + Ord + Send + 'static,
{
fn on_push(
&mut self,
logic: &mut GraphStageLogic,
_inlet: AnyInlet,
) -> StreamResult<()> {
let value: T = logic.grab_datum(self.inlet_id).and_then(|value| {
downcast_datum(value, "grab", || {
format!("inlet#{}", self.inlet_id.as_usize())
})
})?;
{
let mut state = self.state.lock().expect("merge-sorted state poisoned");
if self.inlet_id == self.left.id() {
state.left.push_back(value);
} else {
state.right.push_back(value);
}
while maybe_queue_output(&mut state) {}
}
maybe_drain(logic, &self.outlet, &self.state)?;
let state = self.state.lock().expect("merge-sorted state poisoned");
maybe_complete(logic, &self.outlet, &state)?;
maybe_pull(logic, &self.left, &self.right, &state)
}
fn on_upstream_finish(
&mut self,
logic: &mut GraphStageLogic,
_inlet: AnyInlet,
) -> StreamResult<()> {
{
let mut state = self.state.lock().expect("merge-sorted state poisoned");
if self.inlet_id == self.left.id() {
state.left_closed = true;
} else {
state.right_closed = true;
}
while maybe_queue_output(&mut state) {}
}
maybe_drain(logic, &self.outlet, &self.state)?;
let state = self.state.lock().expect("merge-sorted state poisoned");
maybe_complete(logic, &self.outlet, &state)?;
maybe_pull(logic, &self.left, &self.right, &state)
}
}
struct Out<T: 'static> {
left: Inlet<T>,
right: Inlet<T>,
outlet: Outlet<T>,
state: Arc<Mutex<State<T>>>,
}
impl<T> OutHandler for Out<T>
where
T: Clone + Ord + Send + 'static,
{
fn on_pull(
&mut self,
logic: &mut GraphStageLogic,
_outlet: AnyOutlet,
) -> StreamResult<()> {
maybe_drain(logic, &self.outlet, &self.state)?;
let state = self.state.lock().expect("merge-sorted state poisoned");
maybe_complete(logic, &self.outlet, &state)?;
maybe_pull(logic, &self.left, &self.right, &state)
}
}
let state = Arc::new(Mutex::new(State::<T>::default()));
let left = shape.inlet(0).expect("merge-sorted left inlet");
let right = shape.inlet(1).expect("merge-sorted right inlet");
let outlet = shape.outlet();
let mut logic = GraphStageLogic::new(shape);
logic
.set_handler(
&left,
Box::new(In {
inlet_id: left.id(),
left: left.clone(),
right: right.clone(),
outlet: outlet.clone(),
state: Arc::clone(&state),
}),
)
.unwrap();
logic
.set_handler(
&right,
Box::new(In {
inlet_id: right.id(),
left: left.clone(),
right: right.clone(),
outlet: outlet.clone(),
state: Arc::clone(&state),
}),
)
.unwrap();
logic
.set_out_handler(
&outlet.clone(),
Box::new(Out {
left,
right,
outlet: outlet.clone(),
state,
}),
)
.unwrap();
logic
}
}
/// Fan-in stage that merges `inputs` inlets into one sequence ordered by
/// the number that `extract_sequence` reads from each element.
#[derive(Clone)]
pub struct MergeSequence<T: 'static> {
    // Number of inlets; at least two.
    inputs: usize,
    // Reads the sequence number of an element.
    extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for MergeSequence<T> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The extractor closure is not `Debug`; only the input count is shown.
        let mut builder = formatter.debug_struct("MergeSequence");
        builder.field("inputs", &self.inputs);
        builder.finish_non_exhaustive()
    }
}

impl<T: 'static> MergeSequence<T> {
    /// Creates a merge-sequence stage.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is less than two.
    #[must_use]
    pub fn new<F>(inputs: usize, extract_sequence: F) -> Self
    where
        F: Fn(&T) -> u64 + Send + Sync + 'static,
    {
        assert!(
            inputs >= 2,
            "merge sequence input count must be greater than one"
        );
        let extractor: Arc<dyn Fn(&T) -> u64 + Send + Sync> = Arc::new(extract_sequence);
        MergeSequence {
            inputs,
            extract_sequence: extractor,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for MergeSequence<T>
where
T: Clone + Send + 'static,
{
type Shape = FanInShape<T, T>;
fn name(&self) -> &str {
"MergeSequence"
}
fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
let inlets = (0..self.inputs)
.map(|index| allocator.inlet(format!("MergeSequence.in{index}")))
.collect();
FanInShape::new(inlets, allocator.outlet("MergeSequence.out"))
}
fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
let extract = Arc::clone(&self.extract_sequence);
let extract_sequence = Arc::new(move |dv: &DatumValue| -> u64 {
let t: &T = dv
.as_any_ref()
.downcast_ref::<T>()
.expect("merge-sequence extract: wrong element type");
extract(t)
});
let typed_extract_fn = Arc::clone(&self.extract_sequence);
let typed_extract: Arc<dyn Fn(&T) -> u64 + Send + Sync> = typed_extract_fn;
let typed_extract: Arc<StageTypedSequenceFn> = Arc::new(typed_extract);
StageSpec::merge_sequence(
Arc::from(self.name()),
shape.inlets(),
shape.outlets(),
self.inputs,
extract_sequence,
typed_extract,
)
}
fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
#[derive(Clone)]
struct Pending<T> {
sequence: u64,
elem: T,
}
struct State<T> {
next_sequence: u64,
pending: Vec<Pending<T>>,
completed: usize,
pending_output: VecDeque<T>,
}
fn try_emit_pending<T>(state: &mut State<T>) -> StreamResult<()>
where
T: Clone + Send + 'static,
{
while let Some(index) = state
.pending
.iter()
.position(|item| item.sequence == state.next_sequence)
{
let item = state.pending.remove(index);
if state
.pending
.iter()
.any(|other| other.sequence == state.next_sequence)
{
return Err(StreamError::Failed(format!(
"duplicate sequence {} on merge sequence",
state.next_sequence
)));
}
state.pending_output.push_back(item.elem);
state.next_sequence += 1;
}
Ok(())
}
struct In<T: 'static> {
inlet_id: PortId,
inlet_index: usize,
inlet: Inlet<T>,
all_inlets: Vec<Inlet<T>>,
outlet: Outlet<T>,
extract_sequence: Arc<dyn Fn(&T) -> u64 + Send + Sync>,
state: Arc<Mutex<State<T>>>,
}
impl<T> InHandler for In<T>
where
T: Clone + Send + 'static,
{
fn on_push(
&mut self,
logic: &mut GraphStageLogic,
_inlet: AnyInlet,
) -> StreamResult<()> {
let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
downcast_datum(value, "grab", || {
format!("inlet#{}", self.inlet_id.as_usize())
})
})?;
{
let mut state = self.state.lock().expect("merge-sequence state poisoned");
let sequence = (self.extract_sequence)(&elem);
if sequence < state.next_sequence {
return Err(StreamError::Failed(format!(
"sequence regression from {} to {} on port {}",
state.next_sequence, sequence, self.inlet_index
)));
}
state.pending.push(Pending { sequence, elem });
try_emit_pending(&mut state)?;
}
let next = if logic.is_available(&self.outlet) {
self.state
.lock()
.expect("merge-sequence state poisoned")
.pending_output
.pop_front()
} else {
None
};
if let Some(value) = next {
logic.push(&self.outlet, value)?;
}
let state = self.state.lock().expect("merge-sequence state poisoned");
if state.completed == self.all_inlets.len()
&& state.pending.is_empty()
&& state.pending_output.is_empty()
{
logic.complete(&self.outlet)?;
} else if logic.is_available(&self.outlet)
&& state.pending_output.is_empty()
&& state.pending.len() + state.completed == self.all_inlets.len()
{
return Err(StreamError::Failed(format!(
"expected sequence {}, but all input ports have pushed or are complete",
state.next_sequence
)));
}
if !logic.has_been_pulled(&self.inlet) {
logic.pull(&self.inlet)?;
}
Ok(())
}
fn on_upstream_finish(
&mut self,
logic: &mut GraphStageLogic,
_inlet: AnyInlet,
) -> StreamResult<()> {
{
let mut state = self.state.lock().expect("merge-sequence state poisoned");
state.completed += 1;
}
let state = self.state.lock().expect("merge-sequence state poisoned");
if state.completed == self.all_inlets.len()
&& state.pending.is_empty()
&& state.pending_output.is_empty()
{
logic.complete(&self.outlet)?;
} else if logic.is_available(&self.outlet)
&& state.pending_output.is_empty()
&& state.pending.len() + state.completed == self.all_inlets.len()
{
return Err(StreamError::Failed(format!(
"expected sequence {}, but all input ports have pushed or are complete",
state.next_sequence
)));
}
Ok(())
}
}
struct Out<T: 'static> {
inlets: Vec<Inlet<T>>,
outlet: Outlet<T>,
state: Arc<Mutex<State<T>>>,
}
impl<T> OutHandler for Out<T>
where
T: Clone + Send + 'static,
{
fn on_pull(
&mut self,
logic: &mut GraphStageLogic,
_outlet: AnyOutlet,
) -> StreamResult<()> {
let next = self
.state
.lock()
.expect("merge-sequence state poisoned")
.pending_output
.pop_front();
if let Some(value) = next {
logic.push(&self.outlet, value)?;
} else {
let state = self.state.lock().expect("merge-sequence state poisoned");
if state.completed == self.inlets.len() && state.pending.is_empty() {
logic.complete(&self.outlet)?;
} else if state.pending.len() + state.completed == self.inlets.len() {
return Err(StreamError::Failed(format!(
"expected sequence {}, but all input ports have pushed or are complete",
state.next_sequence
)));
}
}
for inlet in &self.inlets {
if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
logic.pull(inlet)?;
}
}
Ok(())
}
}
let inlets = shape.inlets_vec();
let outlet = shape.outlet();
let state = Arc::new(Mutex::new(State {
next_sequence: 0,
pending: Vec::new(),
completed: 0,
pending_output: VecDeque::new(),
}));
let mut logic = GraphStageLogic::new(shape);
for (index, inlet) in inlets.iter().cloned().enumerate() {
logic
.set_handler(
&inlet.clone(),
Box::new(In {
inlet_id: inlet.id(),
inlet_index: index,
inlet: inlet.clone(),
all_inlets: inlets.clone(),
outlet: outlet.clone(),
extract_sequence: Arc::clone(&self.extract_sequence),
state: Arc::clone(&state),
}),
)
.unwrap();
}
logic
.set_out_handler(
&outlet.clone(),
Box::new(Out {
inlets,
outlet: outlet.clone(),
state,
}),
)
.unwrap();
logic
}
}
/// Fan-in stage with `inputs` inlets that emits `Vec<T>` snapshots of the
/// latest value seen on every inlet.
#[derive(Clone, Debug)]
pub struct MergeLatest<T: 'static> {
    // Number of inlets; `new` guarantees it is non-zero.
    inputs: usize,
    // Whether the first finished inlet completes the stage.
    eager_complete: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> MergeLatest<T> {
    /// Creates a merge-latest stage.
    ///
    /// # Panics
    ///
    /// Panics if `inputs` is zero.
    #[must_use]
    pub fn new(inputs: usize, eager_complete: bool) -> Self {
        assert!(
            inputs != 0,
            "merge-latest input count must be greater than zero"
        );
        MergeLatest {
            inputs,
            eager_complete,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for MergeLatest<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanInShape<T, Vec<T>>;

    fn name(&self) -> &str {
        "MergeLatest"
    }

    /// Allocates inlets `MergeLatest.in0` .. `MergeLatest.in{inputs - 1}` in
    /// index order, then `MergeLatest.out`.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlets = (0..self.inputs)
            .map(|index| allocator.inlet(format!("MergeLatest.in{index}")))
            .collect();
        FanInShape::new(inlets, allocator.outlet("MergeLatest.out"))
    }

    /// Builds the spec with two snapshot builders:
    /// * `build_snapshot` is type-erased: it downcasts each latest value to
    ///   `T` and wraps the resulting `Vec<T>` with `datum`;
    /// * `typed_snapshot` works on `Option<T>` slots directly.
    ///
    /// # Panics
    ///
    /// `build_snapshot` panics on a value that is not a `T`; `typed_snapshot`
    /// panics if any slot is `None`.
    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let build_snapshot = Arc::new(move |values: &[&DatumValue]| -> DatumValue {
            let snapshot: Vec<T> = values
                .iter()
                .map(|dv| {
                    dv.as_any_ref()
                        .downcast_ref::<T>()
                        .cloned()
                        .expect("merge-latest snapshot: wrong element type")
                })
                .collect();
            datum(snapshot)
        });
        // The explicit trait-object type is what lets the closure be wrapped as
        // a `StageTypedSnapshotFn`; clippy flags that type as too complex.
        #[allow(clippy::type_complexity)]
        let typed_snapshot_fn: Arc<dyn Fn(&[Option<T>]) -> Vec<T> + Send + Sync> =
            Arc::new(move |slots: &[Option<T>]| {
                slots
                    .iter()
                    .map(|s| {
                        s.clone()
                            .expect("merge-latest typed snapshot: slot is None")
                    })
                    .collect()
            });
        let typed_snapshot: Arc<StageTypedSnapshotFn> = Arc::new(typed_snapshot_fn);
        StageSpec::merge_latest(
            Arc::from(self.name()),
            shape.inlets(),
            shape.outlets(),
            self.inputs,
            self.eager_complete,
            build_snapshot,
            typed_snapshot,
        )
    }

    /// Builds the interpreter logic.
    ///
    /// Every push updates that inlet's latest value. Once every inlet has
    /// produced at least one value, each push enqueues a full snapshot in
    /// `pending`, which is pushed when the outlet has demand.
    ///
    /// The stage starts finishing when all inlets have finished or, with
    /// `eager_complete`, when the first one has. From then on inlets are no
    /// longer pulled, and the outlet completes only after every queued
    /// snapshot has been delivered, so no snapshot is dropped.
    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
        struct State<T> {
            // Latest value per inlet; `None` until that inlet first pushes.
            latest: Vec<Option<T>>,
            // Number of inlets that have pushed at least once.
            seen: usize,
            // Number of finished inlets.
            completed: usize,
            // Snapshots awaiting downstream demand.
            pending: VecDeque<Vec<T>>,
            eager_complete: bool,
        }
        impl<T> State<T> {
            // True once the stage should stop consuming input and complete as
            // soon as the queued snapshots are delivered.
            fn is_finishing(&self) -> bool {
                self.completed == self.latest.len() || (self.eager_complete && self.completed > 0)
            }
        }
        // Completes the outlet when finishing with nothing left to deliver.
        fn complete_if_drained<T>(
            logic: &mut GraphStageLogic,
            outlet: &Outlet<Vec<T>>,
            state: &State<T>,
        ) -> StreamResult<()>
        where
            T: Clone + Send + 'static,
        {
            if state.is_finishing() && state.pending.is_empty() && !logic.is_closed(outlet) {
                logic.complete(outlet)?;
            }
            Ok(())
        }
        struct In<T: 'static> {
            inlet_id: PortId,
            inlet_index: usize,
            inlet: Inlet<T>,
            outlet: Outlet<Vec<T>>,
            state: Arc<Mutex<State<T>>>,
        }
        impl<T> InHandler for In<T>
        where
            T: Clone + Send + 'static,
        {
            fn on_push(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                let elem: T = logic.grab_datum(self.inlet_id).and_then(|value| {
                    downcast_datum(value, "grab", || {
                        format!("inlet#{}", self.inlet_id.as_usize())
                    })
                })?;
                {
                    let mut state = self.state.lock().expect("merge-latest state poisoned");
                    if state.latest[self.inlet_index].is_none() {
                        state.seen += 1;
                    }
                    state.latest[self.inlet_index] = Some(elem);
                    if state.seen == state.latest.len() {
                        let snapshot = state
                            .latest
                            .iter()
                            .map(|item| item.clone().expect("merge-latest seen"))
                            .collect();
                        state.pending.push_back(snapshot);
                    }
                }
                let next = if logic.is_available(&self.outlet) {
                    self.state
                        .lock()
                        .expect("merge-latest state poisoned")
                        .pending
                        .pop_front()
                } else {
                    None
                };
                if let Some(value) = next {
                    logic.push(&self.outlet, value)?;
                }
                let state = self.state.lock().expect("merge-latest state poisoned");
                complete_if_drained(logic, &self.outlet, &state)?;
                // Once finishing, stop requesting input so queued snapshots can
                // drain and the outlet can complete.
                if !state.is_finishing() && !logic.has_been_pulled(&self.inlet) {
                    logic.pull(&self.inlet)?;
                }
                Ok(())
            }
            fn on_upstream_finish(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                let mut state = self.state.lock().expect("merge-latest state poisoned");
                state.completed += 1;
                // Queued snapshots are delivered by `Out::on_pull` before completion.
                complete_if_drained(logic, &self.outlet, &state)
            }
        }
        struct Out<T: 'static> {
            inlets: Vec<Inlet<T>>,
            outlet: Outlet<Vec<T>>,
            state: Arc<Mutex<State<T>>>,
        }
        impl<T> OutHandler for Out<T>
        where
            T: Clone + Send + 'static,
        {
            fn on_pull(
                &mut self,
                logic: &mut GraphStageLogic,
                _outlet: AnyOutlet,
            ) -> StreamResult<()> {
                let next = self
                    .state
                    .lock()
                    .expect("merge-latest state poisoned")
                    .pending
                    .pop_front();
                if let Some(value) = next {
                    logic.push(&self.outlet, value)?;
                }
                let state = self.state.lock().expect("merge-latest state poisoned");
                if state.is_finishing() {
                    return complete_if_drained(logic, &self.outlet, &state);
                }
                for inlet in &self.inlets {
                    if !logic.has_been_pulled(inlet) && !logic.is_closed(inlet) {
                        logic.pull(inlet)?;
                    }
                }
                Ok(())
            }
        }
        let inlets = shape.inlets_vec();
        let outlet = shape.outlet();
        let state = Arc::new(Mutex::new(State {
            latest: vec![None; inlets.len()],
            seen: 0,
            completed: 0,
            pending: VecDeque::new(),
            eager_complete: self.eager_complete,
        }));
        let mut logic = GraphStageLogic::new(shape);
        for (index, inlet) in inlets.iter().enumerate() {
            logic
                .set_handler(
                    inlet,
                    Box::new(In {
                        inlet_id: inlet.id(),
                        inlet_index: index,
                        inlet: inlet.clone(),
                        outlet: outlet.clone(),
                        state: Arc::clone(&state),
                    }),
                )
                .expect("merge-latest: failed to install inlet handler");
        }
        logic
            .set_out_handler(
                &outlet,
                Box::new(Out {
                    inlets,
                    outlet: outlet.clone(),
                    state,
                }),
            )
            .expect("merge-latest: failed to install outlet handler");
        logic
    }
}
/// Fan-out stage with one inlet and `outputs` outlets; each element is routed
/// to the outlet whose index `partitioner` returns for it.
#[derive(Clone)]
pub struct Partition<T: 'static> {
    // Number of outlets; `new_with_eager_cancel` guarantees it is non-zero.
    outputs: usize,
    // Maps an element to its outlet index.
    partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
    // Forwarded to `StageSpec::partition`; presumably decides whether one
    // cancelled outlet cancels the whole stage. NOTE(review): confirm.
    eager_cancel: bool,
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> fmt::Debug for Partition<T> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The partitioner closure is not `Debug`, so it is omitted.
        let mut builder = formatter.debug_struct("Partition");
        builder.field("outputs", &self.outputs);
        builder.field("eager_cancel", &self.eager_cancel);
        builder.finish_non_exhaustive()
    }
}

impl<T: 'static> Partition<T> {
    /// Creates a partition stage with `eager_cancel` set to `false`.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new<F>(outputs: usize, partitioner: F) -> Self
    where
        F: Fn(&T) -> usize + Send + Sync + 'static,
    {
        Partition::new_with_eager_cancel(outputs, partitioner, false)
    }

    /// Creates a partition stage.
    ///
    /// # Panics
    ///
    /// Panics if `outputs` is zero.
    #[must_use]
    pub fn new_with_eager_cancel<F>(outputs: usize, partitioner: F, eager_cancel: bool) -> Self
    where
        F: Fn(&T) -> usize + Send + Sync + 'static,
    {
        assert!(
            outputs != 0,
            "partition output count must be greater than zero"
        );
        let router: Arc<dyn Fn(&T) -> usize + Send + Sync> = Arc::new(partitioner);
        Partition {
            outputs,
            partitioner: router,
            eager_cancel,
            _marker: PhantomData,
        }
    }
}
impl<T> GraphStage for Partition<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FanOutShape<T, T>;

    fn name(&self) -> &str {
        "Partition"
    }

    /// Allocates `Partition.in` followed by `Partition.out0` ..
    /// `Partition.out{outputs-1}`.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlet = allocator.inlet("Partition.in");
        let outlets = (0..self.outputs)
            .map(|index| allocator.outlet(format!("Partition.out{index}")))
            .collect();
        FanOutShape::new(inlet, outlets)
    }

    /// Builds the spec with a type-erased partitioner. The erased closure
    /// panics if it is handed a value that is not a `T`.
    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let partitioner_clone = Arc::clone(&self.partitioner);
        let partitioner = Arc::new(move |dv: &DatumValue| -> usize {
            let t: &T = dv
                .as_any_ref()
                .downcast_ref::<T>()
                .expect("partition: wrong element type");
            partitioner_clone(t)
        });
        StageSpec::partition(
            Arc::from(self.name()),
            shape.inlets(),
            shape.outlets(),
            self.outputs,
            partitioner,
            self.eager_cancel,
        )
    }

    /// Builds the push/pull logic.
    ///
    /// At most one element is buffered (`pending`). While it is buffered,
    /// upstream is not pulled and completion is deferred. The buffer is
    /// released only when its target outlet pulls it or cancels.
    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
        /// State shared by the inlet handler and every outlet handler. The lock
        /// is released before any push/pull/complete call into the logic.
        struct State<T> {
            /// Parked element and its target outlet index. It is set only when
            /// that outlet had no demand.
            pending: Option<(usize, T)>,
            /// Set once upstream has finished.
            upstream_closed: bool,
            /// Number of outlets that have not cancelled yet.
            live_outlets: usize,
            /// `cancelled[i]` becomes true once outlet `i` cancels.
            cancelled: Vec<bool>,
            /// Shut the stage down as soon as any single outlet cancels.
            eager_cancel: bool,
        }

        /// True when some not-yet-cancelled outlet currently has demand.
        fn any_live_demand<T>(
            logic: &GraphStageLogic,
            outlets: &[Outlet<T>],
            cancelled: &[bool],
        ) -> bool
        where
            T: Clone + Send + 'static,
        {
            outlets
                .iter()
                .zip(cancelled)
                .any(|(outlet, &gone)| !gone && logic.is_available(outlet))
        }

        /// Completes every outlet that is not already closed.
        fn complete_outlets<T>(
            logic: &mut GraphStageLogic,
            outlets: &[Outlet<T>],
        ) -> StreamResult<()>
        where
            T: Clone + Send + 'static,
        {
            for outlet in outlets {
                if !logic.is_closed(outlet) {
                    logic.complete(outlet)?;
                }
            }
            Ok(())
        }

        struct In<T: 'static> {
            inlet_id: PortId,
            inlet: Inlet<T>,
            outlets: Vec<Outlet<T>>,
            partitioner: Arc<dyn Fn(&T) -> usize + Send + Sync>,
            state: Arc<Mutex<State<T>>>,
        }

        impl<T> InHandler for In<T>
        where
            T: Clone + Send + 'static,
        {
            /// Routes the grabbed element to the outlet picked by the partitioner.
            /// Elements for cancelled outlets are dropped. Elements for outlets
            /// without demand are parked and upstream is not pulled again.
            ///
            /// # Errors
            ///
            /// Fails on grab/downcast errors, when the partitioner returns an
            /// index `>= outlets.len()`, or when push/pull fails.
            fn on_push(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                let item: T = logic.grab_datum(self.inlet_id).and_then(|value| {
                    downcast_datum(value, "grab", || {
                        format!("inlet#{}", self.inlet_id.as_usize())
                    })
                })?;
                let idx = (self.partitioner)(&item);
                if idx >= self.outlets.len() {
                    return Err(StreamError::Failed(format!(
                        "partitioner returned out-of-bounds index {idx} for {} outputs",
                        self.outlets.len()
                    )));
                }
                let deliver = {
                    let mut state = self.state.lock().expect("partition state poisoned");
                    if state.cancelled[idx] {
                        false
                    } else if logic.is_available(&self.outlets[idx]) {
                        true
                    } else {
                        // Park it; upstream stays un-pulled until it is handed off.
                        state.pending = Some((idx, item));
                        return Ok(());
                    }
                };
                if deliver {
                    logic.push(&self.outlets[idx], item)?;
                }
                // When `deliver` is false the target cancelled: the item is dropped.
                let pull_again = {
                    let state = self.state.lock().expect("partition state poisoned");
                    !state.upstream_closed
                        && any_live_demand(logic, &self.outlets, &state.cancelled)
                };
                if pull_again && !logic.has_been_pulled(&self.inlet) {
                    logic.pull(&self.inlet)?;
                }
                Ok(())
            }

            /// Records upstream completion. Outlets are completed immediately
            /// unless an element is still parked, in which case completion is
            /// left to whichever outlet handler releases it.
            fn on_upstream_finish(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                let nothing_pending = {
                    let mut state = self.state.lock().expect("partition state poisoned");
                    state.upstream_closed = true;
                    state.pending.is_none()
                };
                if nothing_pending {
                    complete_outlets(logic, &self.outlets)?;
                }
                Ok(())
            }
        }

        struct Out<T: 'static> {
            index: usize,
            inlet: Inlet<T>,
            outlets: Vec<Outlet<T>>,
            state: Arc<Mutex<State<T>>>,
        }

        impl<T> OutHandler for Out<T>
        where
            T: Clone + Send + 'static,
        {
            /// Hands over the parked element if it targets this outlet, then
            /// either completes (upstream closed) or pulls upstream when some
            /// live outlet has demand.
            fn on_pull(
                &mut self,
                logic: &mut GraphStageLogic,
                _outlet: AnyOutlet,
            ) -> StreamResult<()> {
                let pending = {
                    let mut state = self.state.lock().expect("partition state poisoned");
                    if state
                        .pending
                        .as_ref()
                        .is_some_and(|(idx, _)| *idx != self.index)
                    {
                        // Another outlet owns the parked element. Pulling upstream
                        // now could overwrite it, and completing now would discard
                        // it. This demand is served after that element moves on.
                        return Ok(());
                    }
                    state.pending.take()
                };
                if let Some((_, item)) = pending {
                    logic.push(&self.outlets[self.index], item)?;
                }
                let (upstream_closed, wants_more) = {
                    let state = self.state.lock().expect("partition state poisoned");
                    let wants_more = !state.upstream_closed
                        && any_live_demand(logic, &self.outlets, &state.cancelled);
                    (state.upstream_closed, wants_more)
                };
                if upstream_closed {
                    complete_outlets(logic, &self.outlets)?;
                } else if wants_more && !logic.has_been_pulled(&self.inlet) {
                    logic.pull(&self.inlet)?;
                }
                Ok(())
            }

            /// Marks this outlet cancelled, dropping any element parked for it.
            /// Shuts the stage down when `eager_cancel` is set or no outlet is
            /// left. Otherwise it either finishes a completion that the dropped
            /// element was deferring, or resumes pulling upstream.
            fn on_downstream_finish(
                &mut self,
                logic: &mut GraphStageLogic,
                _outlet: AnyOutlet,
            ) -> StreamResult<()> {
                let (cancel_stage, dropped_pending, upstream_closed) = {
                    let mut state = self.state.lock().expect("partition state poisoned");
                    if state.cancelled[self.index] {
                        return Ok(());
                    }
                    state.cancelled[self.index] = true;
                    state.live_outlets -= 1;
                    let dropped_pending = state
                        .pending
                        .as_ref()
                        .is_some_and(|(idx, _)| *idx == self.index);
                    if dropped_pending {
                        state.pending = None;
                    }
                    (
                        state.eager_cancel || state.live_outlets == 0,
                        dropped_pending,
                        state.upstream_closed,
                    )
                };
                if cancel_stage {
                    logic.complete_stage()?;
                } else if dropped_pending {
                    if upstream_closed {
                        // The discarded element was the only thing deferring
                        // completion; finish the remaining outlets now.
                        complete_outlets(logic, &self.outlets)?;
                    } else if !logic.has_been_pulled(&self.inlet)
                        && !logic.is_closed(&self.inlet)
                    {
                        let demand = {
                            let state =
                                self.state.lock().expect("partition state poisoned");
                            any_live_demand(logic, &self.outlets, &state.cancelled)
                        };
                        if demand {
                            logic.pull(&self.inlet)?;
                        }
                    }
                }
                Ok(())
            }
        }

        let inlet = shape.inlet();
        let outlets = shape.outlets_vec();
        let state = Arc::new(Mutex::new(State {
            pending: None,
            upstream_closed: false,
            live_outlets: outlets.len(),
            cancelled: vec![false; outlets.len()],
            eager_cancel: self.eager_cancel,
        }));
        let mut logic = GraphStageLogic::new(shape);
        logic
            .set_handler(
                &inlet,
                Box::new(In {
                    inlet_id: inlet.id(),
                    inlet: inlet.clone(),
                    outlets: outlets.clone(),
                    partitioner: Arc::clone(&self.partitioner),
                    state: Arc::clone(&state),
                }),
            )
            .unwrap();
        for (index, outlet) in outlets.iter().cloned().enumerate() {
            logic
                .set_out_handler(
                    &outlet,
                    Box::new(Out {
                        index,
                        inlet: inlet.clone(),
                        outlets: outlets.clone(),
                        state: Arc::clone(&state),
                    }),
                )
                .unwrap();
        }
        logic
    }
}
/// Fan-out stage that splits each `(A, B)` pair into its components: the
/// first goes to `out0`, the second to `out1`.
#[derive(Clone, Debug)]
pub struct Unzip<A: 'static, B: 'static> {
    // `fn() -> (A, B)` ties the type parameters to the stage without storing
    // any value of those types.
    _marker: PhantomData<fn() -> (A, B)>,
}

impl<A: 'static, B: 'static> Unzip<A, B> {
    /// Creates the stage; it carries no configuration.
    #[must_use]
    pub fn new() -> Self {
        Unzip {
            _marker: PhantomData,
        }
    }
}

impl<A: 'static, B: 'static> Default for Unzip<A, B> {
    /// Same as [`Unzip::new`].
    fn default() -> Self {
        Unzip::new()
    }
}
impl<A, B> GraphStage for Unzip<A, B>
where
    A: Clone + Send + 'static,
    B: Clone + Send + 'static,
{
    type Shape = FanOutShape2<(A, B), A, B>;

    fn name(&self) -> &str {
        "Unzip"
    }

    /// Allocates the pair inlet `Unzip.in`, then `Unzip.out0` and `Unzip.out1`.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        let inlet = allocator.inlet("Unzip.in");
        let first_out = allocator.outlet("Unzip.out0");
        let second_out = allocator.outlet("Unzip.out1");
        FanOutShape2::new(inlet, first_out, second_out)
    }

    /// Describes the stage with two splitters. The type-erased one panics when
    /// handed a value that is not `(A, B)`. The typed one is the identity,
    /// since each element already is the pair.
    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let erased = Arc::new(|value: DatumValue| -> (DatumValue, DatumValue) {
            let (first, second): (A, B) = downcast_datum(value, "unzip", || "Unzip.in")
                .expect("unzip: wrong element type");
            (datum(first), datum(second))
        });
        // Explicit trait-object type: this value is wrapped as a
        // `StageTypedUnzipFn` just below.
        #[allow(clippy::type_complexity)]
        let identity: Arc<dyn Fn((A, B)) -> (A, B) + Send + Sync> =
            Arc::new(|pair: (A, B)| pair);
        let typed: Arc<StageTypedUnzipFn> = Arc::new(identity);
        StageSpec::unzip(
            Arc::from(self.name()),
            shape.inlets(),
            shape.outlets(),
            erased,
            typed,
        )
    }

    /// Reuses the [`UnzipWith`] logic with an identity split.
    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
        let identity_split = UnzipWith::<(A, B), A, B>::new(|pair| pair);
        identity_split.create_logic(shape)
    }
}
/// Fan-out stage that applies `split` to each element and emits the two
/// halves of the result on separate outlets.
#[derive(Clone)]
pub struct UnzipWith<In: 'static, Out0: 'static, Out1: 'static> {
    // Held behind `Arc` so the stage and its spec/logic share one closure.
    split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
    _marker: PhantomData<fn(In) -> (Out0, Out1)>,
}

impl<In: 'static, Out0: 'static, Out1: 'static> fmt::Debug for UnzipWith<In, Out0, Out1> {
    /// Prints only the type name; the closure has no `Debug` form.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.debug_struct("UnzipWith").finish_non_exhaustive()
    }
}

impl<In: 'static, Out0: 'static, Out1: 'static> UnzipWith<In, Out0, Out1> {
    /// Wraps `split`, which maps each input element to the pair
    /// `(value for out0, value for out1)`.
    #[must_use]
    pub fn new<F>(split: F) -> Self
    where
        F: Fn(In) -> (Out0, Out1) + Send + Sync + 'static,
    {
        let shared: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> = Arc::new(split);
        UnzipWith {
            split: shared,
            _marker: PhantomData,
        }
    }
}
impl<In, Out0, Out1> GraphStage for UnzipWith<In, Out0, Out1>
where
    In: Clone + Send + 'static,
    Out0: Clone + Send + 'static,
    Out1: Clone + Send + 'static,
{
    type Shape = FanOutShape2<In, Out0, Out1>;

    fn name(&self) -> &str {
        "UnzipWith"
    }

    /// Allocates `UnzipWith.in`, then `UnzipWith.out0` and `UnzipWith.out1`.
    fn allocate_shape(&self, allocator: &mut PortAllocator) -> Self::Shape {
        FanOutShape2::new(
            allocator.inlet("UnzipWith.in"),
            allocator.outlet("UnzipWith.out0"),
            allocator.outlet("UnzipWith.out1"),
        )
    }

    /// Builds the spec with a type-erased splitter (datum in, two datums out)
    /// and the typed `split` closure. The erased splitter panics if it is
    /// handed a value that is not an `In`.
    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let split_fn = Arc::clone(&self.split);
        let split = Arc::new(move |dv: DatumValue| -> (DatumValue, DatumValue) {
            let value: In = downcast_datum(dv, "unzip_with", || "UnzipWith.in")
                .expect("unzip-with: wrong element type");
            let (out0, out1) = split_fn(value);
            (datum(out0), datum(out1))
        });
        let typed_split_fn: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync> =
            Arc::clone(&self.split);
        let typed_split: Arc<StageTypedUnzipFn> = Arc::new(typed_split_fn);
        StageSpec::unzip(
            Arc::from(self.name()),
            shape.inlets(),
            shape.outlets(),
            split,
            typed_split,
        )
    }

    /// Builds the push/pull logic.
    ///
    /// Each upstream element is split and both halves are pushed together, so
    /// upstream is pulled only when every still-open outlet has demand.
    /// Cancelling one outlet keeps the other running. The stage completes once
    /// both outlets have cancelled.
    fn create_logic(&self, shape: &Self::Shape) -> GraphStageLogic {
        // Shared by the inlet handler and both outlet handlers. The lock is
        // released before any push/pull/complete call into the logic.
        struct State {
            left_open: bool,
            right_open: bool,
            upstream_closed: bool,
        }

        /// Upstream may be pulled when at least one outlet is still open and
        /// every open outlet currently has demand.
        fn ready_to_pull<L, R>(
            logic: &GraphStageLogic,
            state: &State,
            out0: &Outlet<L>,
            out1: &Outlet<R>,
        ) -> bool
        where
            L: Clone + Send + 'static,
            R: Clone + Send + 'static,
        {
            let left_ready = !state.left_open || logic.is_available(out0);
            let right_ready = !state.right_open || logic.is_available(out1);
            (state.left_open || state.right_open) && left_ready && right_ready
        }

        /// Completes whichever of the two outlets is not already closed.
        fn complete_open_outlets<L, R>(
            logic: &mut GraphStageLogic,
            out0: &Outlet<L>,
            out1: &Outlet<R>,
        ) -> StreamResult<()>
        where
            L: Clone + Send + 'static,
            R: Clone + Send + 'static,
        {
            if !logic.is_closed(out0) {
                logic.complete(out0)?;
            }
            if !logic.is_closed(out1) {
                logic.complete(out1)?;
            }
            Ok(())
        }

        struct InHandlerState<In: 'static, Out0: 'static, Out1: 'static> {
            inlet_id: PortId,
            inlet: Inlet<In>,
            out0: Outlet<Out0>,
            out1: Outlet<Out1>,
            split: Arc<dyn Fn(In) -> (Out0, Out1) + Send + Sync>,
            state: Arc<Mutex<State>>,
        }

        impl<In, Out0, Out1> InHandler for InHandlerState<In, Out0, Out1>
        where
            In: Clone + Send + 'static,
            Out0: Clone + Send + 'static,
            Out1: Clone + Send + 'static,
        {
            /// Splits the grabbed element and pushes each half to its outlet if
            /// that outlet is still open. Then pulls again when both sides are
            /// ready.
            fn on_push(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                let value: In = logic.grab_datum(self.inlet_id).and_then(|value| {
                    downcast_datum(value, "grab", || {
                        format!("inlet#{}", self.inlet_id.as_usize())
                    })
                })?;
                let (left, right) = (self.split)(value);
                let (left_open, right_open) = {
                    let state = self.state.lock().expect("unzip-with state poisoned");
                    (state.left_open, state.right_open)
                };
                if left_open {
                    logic.push(&self.out0, left)?;
                }
                if right_open {
                    logic.push(&self.out1, right)?;
                }
                let should_pull = {
                    let state = self.state.lock().expect("unzip-with state poisoned");
                    ready_to_pull(logic, &state, &self.out0, &self.out1)
                };
                if should_pull && !logic.has_been_pulled(&self.inlet) {
                    logic.pull(&self.inlet)?;
                }
                Ok(())
            }

            /// Records upstream completion and completes both outlets. Nothing
            /// is buffered, so completion is never deferred.
            fn on_upstream_finish(
                &mut self,
                logic: &mut GraphStageLogic,
                _inlet: AnyInlet,
            ) -> StreamResult<()> {
                self.state
                    .lock()
                    .expect("unzip-with state poisoned")
                    .upstream_closed = true;
                complete_open_outlets(logic, &self.out0, &self.out1)
            }
        }

        struct Out<In: 'static, Out0: 'static, Out1: 'static> {
            is_left: bool,
            inlet: Inlet<In>,
            out0: Outlet<Out0>,
            out1: Outlet<Out1>,
            state: Arc<Mutex<State>>,
        }

        impl<In, Out0, Out1> OutHandler for Out<In, Out0, Out1>
        where
            In: Clone + Send + 'static,
            Out0: Clone + Send + 'static,
            Out1: Clone + Send + 'static,
        {
            /// Completes the outlets if upstream already finished. Otherwise
            /// pulls upstream once every open outlet has demand.
            fn on_pull(
                &mut self,
                logic: &mut GraphStageLogic,
                _outlet: AnyOutlet,
            ) -> StreamResult<()> {
                let (upstream_closed, should_pull) = {
                    let state = self.state.lock().expect("unzip-with state poisoned");
                    (
                        state.upstream_closed,
                        ready_to_pull(logic, &state, &self.out0, &self.out1),
                    )
                };
                if upstream_closed {
                    complete_open_outlets(logic, &self.out0, &self.out1)?;
                } else if should_pull && !logic.has_been_pulled(&self.inlet) {
                    logic.pull(&self.inlet)?;
                }
                Ok(())
            }

            /// Marks this side closed. The stage completes when both sides are
            /// closed. Otherwise upstream is pulled if the surviving side now
            /// has all it needs.
            fn on_downstream_finish(
                &mut self,
                logic: &mut GraphStageLogic,
                _outlet: AnyOutlet,
            ) -> StreamResult<()> {
                let (all_closed, should_pull) = {
                    let mut state = self.state.lock().expect("unzip-with state poisoned");
                    if self.is_left {
                        state.left_open = false;
                    } else {
                        state.right_open = false;
                    }
                    let all_closed = !state.left_open && !state.right_open;
                    let should_pull = !all_closed
                        && !state.upstream_closed
                        && ready_to_pull(logic, &state, &self.out0, &self.out1);
                    (all_closed, should_pull)
                };
                if all_closed {
                    logic.complete_stage()?;
                    return Ok(());
                }
                if should_pull && !logic.has_been_pulled(&self.inlet) {
                    logic.pull(&self.inlet)?;
                }
                Ok(())
            }
        }

        let inlet = shape.inlet();
        let out0 = shape.out0();
        let out1 = shape.out1();
        let state = Arc::new(Mutex::new(State {
            left_open: true,
            right_open: true,
            upstream_closed: false,
        }));
        let mut logic = GraphStageLogic::new(shape);
        logic
            .set_handler(
                &inlet,
                Box::new(InHandlerState {
                    inlet_id: inlet.id(),
                    inlet: inlet.clone(),
                    out0: out0.clone(),
                    out1: out1.clone(),
                    split: Arc::clone(&self.split),
                    state: Arc::clone(&state),
                }),
            )
            .unwrap();
        logic
            .set_out_handler(
                &out0,
                Box::new(Out {
                    is_left: true,
                    inlet: inlet.clone(),
                    out0: out0.clone(),
                    out1: out1.clone(),
                    state: Arc::clone(&state),
                }),
            )
            .unwrap();
        logic
            .set_out_handler(
                &out1,
                Box::new(Out {
                    is_left: false,
                    inlet: inlet.clone(),
                    out0: out0.clone(),
                    out1: out1.clone(),
                    state,
                }),
            )
            .unwrap();
        logic
    }
}
/// Pass-through flow stage (`T` in, `T` out) that is described to the
/// runtime as an async boundary.
#[derive(Clone, Debug)]
pub struct AsyncBoundary<T: 'static> {
    // Ties `T` to the stage without storing a `T`.
    _marker: PhantomData<fn() -> T>,
}

impl<T: 'static> AsyncBoundary<T> {
    /// Creates the stage; it carries no configuration.
    #[must_use]
    pub fn new() -> Self {
        AsyncBoundary {
            _marker: PhantomData,
        }
    }
}

impl<T: 'static> Default for AsyncBoundary<T> {
    /// Same as [`AsyncBoundary::new`].
    fn default() -> Self {
        AsyncBoundary::new()
    }
}
impl<T> GraphStage for AsyncBoundary<T>
where
    T: Clone + Send + 'static,
{
    type Shape = FlowShape<T, T>;

    fn name(&self) -> &str {
        "AsyncBoundary"
    }

    /// Builds a one-in/one-out shape. Port ids come from
    /// `next_port_id_block(2)` rather than from the supplied allocator, and
    /// the outlet takes the id right after the inlet's.
    fn allocate_shape(&self, _allocator: &mut PortAllocator) -> Self::Shape {
        let base = next_port_id_block(2);
        let inlet = Inlet::with_arc_name(base, async_boundary_inlet_name());
        let outlet = Outlet::with_arc_name(base.offset(1), async_boundary_outlet_name());
        FlowShape::new(inlet, outlet)
    }

    /// Delegates to `stage_spec_with_ports` using the shape's own ports.
    fn stage_spec(&self, shape: &Self::Shape) -> StageSpec {
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.stage_spec_with_ports(shape, inlets, outlets)
    }

    /// Emits an async-boundary spec for the given ports; the shape argument is
    /// not consulted.
    fn stage_spec_with_ports(
        &self,
        _shape: &Self::Shape,
        inlets: Vec<AnyInlet>,
        outlets: Vec<AnyOutlet>,
    ) -> StageSpec {
        StageSpec::async_boundary(async_boundary_stage_name(), inlets, outlets)
    }
}