use super::*;
/// A deferred wiring instruction that can be applied to a [`GraphBuilder`].
///
/// Implementors describe *which* ports should be connected; nothing happens
/// until [`WireSpec::apply`] resolves those ports against a builder.
pub trait WireSpec {
    /// Resolves the described ports against `builder` and connects them.
    ///
    /// # Errors
    ///
    /// Implementations report port-selection or connection failures as `Err`.
    fn apply(self, builder: &mut GraphBuilder) -> StreamResult<()>;
}
/// Fluent wiring helpers available on every cloneable [`Shape`].
///
/// Each helper clones the shape handle, so the original handle stays usable
/// for further wiring calls (e.g. wiring the same fan-out shape twice).
pub trait WireDsl: Shape + Clone {
    /// Describes a connection from an automatically selected outlet of `self`
    /// to `inlet`.
    ///
    /// The outlet is chosen when the returned [`WirePair`] is applied, as the
    /// first outlet the builder does not yet report as connected. The pair is
    /// inert until it is handed to the builder (e.g. `builder.wire(..)`), so
    /// dropping it unused is almost certainly a mistake.
    #[must_use]
    fn to<D>(&self, inlet: D) -> WirePair<Self, D>
    where
        Self: AutoOutletEndpoint,
    {
        WirePair::new(self.clone(), inlet)
    }

    /// Returns a cursor that explicitly addresses outlet `index` of this shape.
    #[must_use]
    fn out(&self, index: usize) -> OutletCursor<Self> {
        OutletCursor::new(self.clone(), index)
    }

    /// Returns a cursor that explicitly addresses inlet `index` of this shape.
    #[must_use]
    fn in_(&self, index: usize) -> InletCursor<Self> {
        InletCursor::new(self.clone(), index)
    }
}

impl<S> WireDsl for S where S: Shape + Clone {}
/// Marks shapes whose outlet may be picked automatically by the wiring DSL.
///
/// Shapes without an implementation must be wired through explicit cursors or
/// ports; the diagnostic below tells users how.
#[diagnostic::on_unimplemented(
    message = "`{Self}` cannot be auto-wired as a graph DSL outlet",
    note = "MergePreferred and Bidi shapes are explicit-only; use `.preferred()` / `.secondary(i)` / explicit `.out(i)` cursors"
)]
#[doc(hidden)]
pub trait AutoOutletEndpoint: Shape {
    /// Picks the lowest-positioned outlet of `self` that `builder` does not
    /// report as connected.
    ///
    /// # Errors
    ///
    /// Returns `StreamError::GraphValidation` when no unconnected outlet exists.
    fn auto_outlet(&self, builder: &GraphBuilder) -> StreamResult<SelectedOutlet> {
        let candidates = self.outlets();
        select_auto_outlet(&candidates, builder)
    }
}
/// Marks shapes whose inlet may be picked automatically by the wiring DSL.
///
/// Shapes without an implementation must be wired through explicit cursors or
/// ports; the diagnostic below tells users how.
#[diagnostic::on_unimplemented(
    message = "`{Self}` cannot be auto-wired as a graph DSL inlet",
    note = "MergePreferred and Bidi shapes are explicit-only; use `.preferred()` / `.secondary(i)` / explicit `.in_(i)` cursors"
)]
#[doc(hidden)]
pub trait AutoInletEndpoint: Shape {
    /// Picks the lowest-positioned inlet of `self` that `builder` does not
    /// report as connected.
    ///
    /// # Errors
    ///
    /// Returns `StreamError::GraphValidation` when no unconnected inlet exists.
    fn auto_inlet(&self, builder: &GraphBuilder) -> StreamResult<SelectedInlet> {
        let candidates = self.inlets();
        select_auto_inlet(&candidates, builder)
    }
}
// Shapes that support automatic port selection, grouped per shape.
// MergePreferred and Bidi shapes are intentionally absent: their ports must be
// addressed explicitly.
impl<Out: 'static> AutoOutletEndpoint for SourceShape<Out> {}

impl<In: 'static> AutoInletEndpoint for SinkShape<In> {}

impl<In: 'static, Out: 'static> AutoOutletEndpoint for FlowShape<In, Out> {}
impl<In: 'static, Out: 'static> AutoInletEndpoint for FlowShape<In, Out> {}

impl<In: 'static, Out: 'static> AutoOutletEndpoint for FanInShape<In, Out> {}
impl<In: 'static, Out: 'static> AutoInletEndpoint for FanInShape<In, Out> {}

impl<In: 'static, Out: 'static> AutoOutletEndpoint for FanOutShape<In, Out> {}
impl<In: 'static, Out: 'static> AutoInletEndpoint for FanOutShape<In, Out> {}

impl<In, Out0, Out1> AutoOutletEndpoint for FanOutShape2<In, Out0, Out1>
where
    In: 'static,
    Out0: 'static,
    Out1: 'static,
{
}
impl<In, Out0, Out1> AutoInletEndpoint for FanOutShape2<In, Out0, Out1>
where
    In: 'static,
    Out0: 'static,
    Out1: 'static,
{
}

impl<Left: 'static, Right: 'static> AutoOutletEndpoint for ZipShape<Left, Right> {}
impl<Left: 'static, Right: 'static> AutoInletEndpoint for ZipShape<Left, Right> {}
/// Explicit reference to the outlet at position `index` of `shape`.
///
/// Resolving a cursor always yields that exact outlet. Unlike automatic
/// selection, it does not skip outlets that are already connected.
#[derive(Clone, Debug)]
pub struct OutletCursor<S> {
    /// Shape handle whose outlets are indexed.
    shape: S,
    /// Zero-based outlet position within `shape`.
    index: usize,
}

impl<S> OutletCursor<S> {
    /// Wraps `shape` together with the outlet position to address.
    #[must_use]
    pub fn new(shape: S, index: usize) -> Self {
        OutletCursor { shape, index }
    }

    /// Returns the outlet position this cursor addresses.
    #[must_use]
    pub const fn index(&self) -> usize {
        self.index
    }

    /// Pairs this specific outlet with `inlet`, producing a [`WirePair`] that
    /// still has to be applied to a graph builder.
    #[must_use]
    pub fn to<D>(self, inlet: D) -> WirePair<Self, D> {
        WirePair::new(self, inlet)
    }
}
/// Explicit reference to the inlet at position `index` of `shape`.
///
/// Resolving a cursor always yields that exact inlet. Unlike automatic
/// selection, it does not skip inlets that are already connected.
#[derive(Clone, Debug)]
pub struct InletCursor<S> {
    /// Shape handle whose inlets are indexed.
    shape: S,
    /// Zero-based inlet position within `shape`.
    index: usize,
}

impl<S> InletCursor<S> {
    /// Wraps `shape` together with the inlet position to address.
    #[must_use]
    pub fn new(shape: S, index: usize) -> Self {
        InletCursor { shape, index }
    }

    /// Returns the inlet position this cursor addresses.
    #[must_use]
    pub const fn index(&self) -> usize {
        self.index
    }
}
/// An outlet endpoint paired with an inlet endpoint, awaiting application.
///
/// Either side may be a shape (auto-selected port), a cursor (explicit port
/// position) or a concrete port; resolution happens in [`WireSpec::apply`].
#[derive(Clone, Debug)]
pub struct WirePair<Out, In> {
    /// Source side of the connection.
    outlet: Out,
    /// Destination side of the connection.
    inlet: In,
}

impl<Out, In> WirePair<Out, In> {
    /// Pairs `outlet` with `inlet` without resolving either endpoint yet.
    #[must_use]
    pub fn new(outlet: Out, inlet: In) -> Self {
        WirePair { outlet, inlet }
    }
}
impl<Out, In> WireSpec for WirePair<Out, In>
where
Out: WireOutletEndpoint,
In: WireInletEndpoint,
{
fn apply(self, builder: &mut GraphBuilder) -> StreamResult<()> {
let outlet = self.outlet.select_outlet(builder)?;
let inlet = self.inlet.select_inlet(builder)?;
builder
.connect_any_unrecorded(outlet.port.clone(), inlet.port.clone())
.map_err(|error| {
StreamError::GraphValidation(format!(
"{} -> {}: {}",
outlet.label, inlet.label, error
))
})
}
}
/// An outlet chosen by the wiring DSL, plus the label used to identify it in
/// connection errors (formatted as `<stem>[<index>]`).
#[doc(hidden)]
#[derive(Clone, Debug)]
pub struct SelectedOutlet {
    /// Type-erased outlet to connect.
    port: AnyOutlet,
    /// Human-readable identifier of `port` for diagnostics.
    label: String,
}

/// An inlet chosen by the wiring DSL, plus the label used to identify it in
/// connection errors (formatted as `<stem>[<index>]`).
#[doc(hidden)]
#[derive(Clone, Debug)]
pub struct SelectedInlet {
    /// Type-erased inlet to connect.
    port: AnyInlet,
    /// Human-readable identifier of `port` for diagnostics.
    label: String,
}
/// Anything that can stand on the outlet side of a `WirePair`.
trait WireOutletEndpoint {
    /// Resolves to one concrete outlet; `builder` is available for selections
    /// that depend on the current connection state.
    fn select_outlet(self, builder: &GraphBuilder) -> StreamResult<SelectedOutlet>;
}

/// Anything that can stand on the inlet side of a `WirePair`.
trait WireInletEndpoint {
    /// Resolves to one concrete inlet; `builder` is available for selections
    /// that depend on the current connection state.
    fn select_inlet(self, builder: &GraphBuilder) -> StreamResult<SelectedInlet>;
}
/// Implements `WireOutletEndpoint` for a shape type and for shared references
/// to it, both delegating to `AutoOutletEndpoint::auto_outlet`.
///
/// Usage: `impl_auto_outlet_endpoint!(P1, P2; Shape<P1, P2>);` where every
/// listed type parameter is given a `'static` bound.
macro_rules! impl_auto_outlet_endpoint {
    ($($param:ident),*; $shape:ty) => {
        impl<$($param: 'static),*> WireOutletEndpoint for $shape {
            fn select_outlet(self, builder: &GraphBuilder) -> StreamResult<SelectedOutlet> {
                <$shape as AutoOutletEndpoint>::auto_outlet(&self, builder)
            }
        }

        impl<$($param: 'static),*> WireOutletEndpoint for &$shape {
            fn select_outlet(self, builder: &GraphBuilder) -> StreamResult<SelectedOutlet> {
                <$shape as AutoOutletEndpoint>::auto_outlet(self, builder)
            }
        }
    };
}
/// Implements `WireInletEndpoint` for a shape type and for shared references
/// to it, both delegating to `AutoInletEndpoint::auto_inlet`.
///
/// Usage: `impl_auto_inlet_endpoint!(P1, P2; Shape<P1, P2>);` where every
/// listed type parameter is given a `'static` bound.
macro_rules! impl_auto_inlet_endpoint {
    ($($param:ident),*; $shape:ty) => {
        impl<$($param: 'static),*> WireInletEndpoint for $shape {
            fn select_inlet(self, builder: &GraphBuilder) -> StreamResult<SelectedInlet> {
                <$shape as AutoInletEndpoint>::auto_inlet(&self, builder)
            }
        }

        impl<$($param: 'static),*> WireInletEndpoint for &$shape {
            fn select_inlet(self, builder: &GraphBuilder) -> StreamResult<SelectedInlet> {
                <$shape as AutoInletEndpoint>::auto_inlet(self, builder)
            }
        }
    };
}
// Make every auto-selectable shape usable as a `WirePair` endpoint, both by
// value and by reference. Grouped per shape.
impl_auto_outlet_endpoint!(Out; SourceShape<Out>);

impl_auto_inlet_endpoint!(In; SinkShape<In>);

impl_auto_outlet_endpoint!(In, Out; FlowShape<In, Out>);
impl_auto_inlet_endpoint!(In, Out; FlowShape<In, Out>);

impl_auto_outlet_endpoint!(In, Out; FanInShape<In, Out>);
impl_auto_inlet_endpoint!(In, Out; FanInShape<In, Out>);

impl_auto_outlet_endpoint!(In, Out; FanOutShape<In, Out>);
impl_auto_inlet_endpoint!(In, Out; FanOutShape<In, Out>);

impl_auto_outlet_endpoint!(In, Out0, Out1; FanOutShape2<In, Out0, Out1>);
impl_auto_inlet_endpoint!(In, Out0, Out1; FanOutShape2<In, Out0, Out1>);

impl_auto_outlet_endpoint!(Left, Right; ZipShape<Left, Right>);
impl_auto_inlet_endpoint!(Left, Right; ZipShape<Left, Right>);
// An outlet cursor resolves to exactly the outlet at its index; the builder's
// connection state is not consulted. The owned form forwards to the borrowed one.
impl<S: Shape> WireOutletEndpoint for OutletCursor<S> {
    fn select_outlet(self, builder: &GraphBuilder) -> StreamResult<SelectedOutlet> {
        <&OutletCursor<S> as WireOutletEndpoint>::select_outlet(&self, builder)
    }
}

impl<S: Shape> WireOutletEndpoint for &OutletCursor<S> {
    fn select_outlet(self, _builder: &GraphBuilder) -> StreamResult<SelectedOutlet> {
        let ports = self.shape.outlets();
        select_indexed_outlet(&ports, self.index)
    }
}
// A bare typed outlet is wired exactly as given; the builder's connection state
// is not consulted. Its label index is taken from the trailing digits of the
// port name (the same suffix `indexed_port_label` strips), defaulting to 0.
// This way a port named e.g. `...out1` is reported as `...out[1]` rather than
// always as `[0]`. Assumes port names encode their position as a trailing
// number; names without one are labelled `[0]` as before.
impl<T: 'static> WireOutletEndpoint for Outlet<T> {
    fn select_outlet(self, builder: &GraphBuilder) -> StreamResult<SelectedOutlet> {
        <&Outlet<T> as WireOutletEndpoint>::select_outlet(&self, builder)
    }
}

impl<T: 'static> WireOutletEndpoint for &Outlet<T> {
    fn select_outlet(self, _builder: &GraphBuilder) -> StreamResult<SelectedOutlet> {
        let port = self.erase();
        let label = {
            let name = port.name();
            let stem_len = name
                .trim_end_matches(|c: char| c.is_ascii_digit())
                .len();
            // Empty or overflowing suffixes fall back to position 0.
            let index = name[stem_len..].parse::<usize>().unwrap_or(0);
            indexed_port_label(name, index, PortKind::Outlet)
        };
        Ok(SelectedOutlet { port, label })
    }
}
// An inlet cursor resolves to exactly the inlet at its index; the builder's
// connection state is not consulted. The owned form forwards to the borrowed one.
impl<S: Shape> WireInletEndpoint for InletCursor<S> {
    fn select_inlet(self, builder: &GraphBuilder) -> StreamResult<SelectedInlet> {
        <&InletCursor<S> as WireInletEndpoint>::select_inlet(&self, builder)
    }
}

impl<S: Shape> WireInletEndpoint for &InletCursor<S> {
    fn select_inlet(self, _builder: &GraphBuilder) -> StreamResult<SelectedInlet> {
        let ports = self.shape.inlets();
        select_indexed_inlet(&ports, self.index)
    }
}
// A bare typed inlet is wired exactly as given; the builder's connection state
// is not consulted. Its label index is taken from the trailing digits of the
// port name (the same suffix `indexed_port_label` strips), defaulting to 0.
// This way a port named e.g. `...in1` is reported as `...in[1]` rather than
// always as `[0]`. Assumes port names encode their position as a trailing
// number; names without one are labelled `[0]` as before.
impl<T: 'static> WireInletEndpoint for Inlet<T> {
    fn select_inlet(self, builder: &GraphBuilder) -> StreamResult<SelectedInlet> {
        <&Inlet<T> as WireInletEndpoint>::select_inlet(&self, builder)
    }
}

impl<T: 'static> WireInletEndpoint for &Inlet<T> {
    fn select_inlet(self, _builder: &GraphBuilder) -> StreamResult<SelectedInlet> {
        let port = self.erase();
        let label = {
            let name = port.name();
            let stem_len = name
                .trim_end_matches(|c: char| c.is_ascii_digit())
                .len();
            // Empty or overflowing suffixes fall back to position 0.
            let index = name[stem_len..].parse::<usize>().unwrap_or(0);
            indexed_port_label(name, index, PortKind::Inlet)
        };
        Ok(SelectedInlet { port, label })
    }
}
/// Returns the first outlet in `outlets` that `builder` does not report as
/// connected, labelled with its position in the slice.
///
/// # Errors
///
/// `StreamError::GraphValidation` when every outlet is connected (or the slice
/// is empty). The message names the next, non-existent position, e.g.
/// `X.out[2]: no unconnected outlet`.
fn select_auto_outlet(
    outlets: &[AnyOutlet],
    builder: &GraphBuilder,
) -> StreamResult<SelectedOutlet> {
    for (position, candidate) in outlets.iter().enumerate() {
        if builder.is_outlet_connected(candidate) {
            continue;
        }
        return Ok(SelectedOutlet {
            port: candidate.clone(),
            label: indexed_port_label(candidate.name(), position, PortKind::Outlet),
        });
    }

    let exhausted = missing_port_label(
        outlets.iter().map(AnyOutlet::name),
        outlets.len(),
        PortKind::Outlet,
    );
    Err(StreamError::GraphValidation(format!(
        "{}: no unconnected outlet",
        exhausted
    )))
}
/// Returns the first inlet in `inlets` that `builder` does not report as
/// connected, labelled with its position in the slice.
///
/// # Errors
///
/// `StreamError::GraphValidation` when every inlet is connected (or the slice
/// is empty). The message names the next, non-existent position, e.g.
/// `X.in[2]: no unconnected inlet`.
fn select_auto_inlet(inlets: &[AnyInlet], builder: &GraphBuilder) -> StreamResult<SelectedInlet> {
    for (position, candidate) in inlets.iter().enumerate() {
        if builder.is_inlet_connected(candidate) {
            continue;
        }
        return Ok(SelectedInlet {
            port: candidate.clone(),
            label: indexed_port_label(candidate.name(), position, PortKind::Inlet),
        });
    }

    let exhausted = missing_port_label(
        inlets.iter().map(AnyInlet::name),
        inlets.len(),
        PortKind::Inlet,
    );
    Err(StreamError::GraphValidation(format!(
        "{}: no unconnected inlet",
        exhausted
    )))
}
/// Returns the outlet at `index`, regardless of its connection state.
///
/// # Errors
///
/// `StreamError::GraphValidation` when `index` is out of range for `outlets`.
fn select_indexed_outlet(outlets: &[AnyOutlet], index: usize) -> StreamResult<SelectedOutlet> {
    match outlets.get(index) {
        Some(target) => Ok(SelectedOutlet {
            port: target.clone(),
            label: indexed_port_label(target.name(), index, PortKind::Outlet),
        }),
        None => {
            let wanted =
                missing_port_label(outlets.iter().map(AnyOutlet::name), index, PortKind::Outlet);
            Err(StreamError::GraphValidation(format!(
                "{}: no outlet at cursor index {index}",
                wanted
            )))
        }
    }
}
/// Returns the inlet at `index`, regardless of its connection state.
///
/// # Errors
///
/// `StreamError::GraphValidation` when `index` is out of range for `inlets`.
fn select_indexed_inlet(inlets: &[AnyInlet], index: usize) -> StreamResult<SelectedInlet> {
    match inlets.get(index) {
        Some(target) => Ok(SelectedInlet {
            port: target.clone(),
            label: indexed_port_label(target.name(), index, PortKind::Inlet),
        }),
        None => {
            let wanted =
                missing_port_label(inlets.iter().map(AnyInlet::name), index, PortKind::Inlet);
            Err(StreamError::GraphValidation(format!(
                "{}: no inlet at cursor index {index}",
                wanted
            )))
        }
    }
}
/// Builds a label for a port position that could not be resolved.
///
/// The first name yielded by `names` supplies the label stem (via
/// `indexed_port_label`), combined with the requested `index`. When `names`
/// is empty, falls back to a generic `inlet[index]` / `outlet[index]`.
fn missing_port_label<'a>(
    mut names: impl Iterator<Item = &'a str>,
    index: usize,
    kind: PortKind,
) -> String {
    if let Some(first) = names.next() {
        return indexed_port_label(first, index, kind);
    }
    match kind {
        PortKind::Inlet => format!("inlet[{index}]"),
        PortKind::Outlet => format!("outlet[{index}]"),
    }
}
/// Formats a port label as `<stem>[<index>]`.
///
/// The stem is `name` with any trailing ASCII digits removed, so `out1` with
/// index 3 becomes `out[3]`. The supplied `index` always wins over the removed
/// digits. A name made only of digits leaves an empty stem. `_kind` is
/// currently unused.
fn indexed_port_label(name: &str, index: usize, _kind: PortKind) -> String {
    let stem = name.trim_end_matches(|ch: char| ch.is_ascii_digit());
    format!("{}[{index}]", stem)
}
#[cfg(test)]
mod tests {
    // Each `*_matches_connect_*` test builds the same graph twice, once with
    // explicit `connect` calls and once with the wiring DSL, and checks that
    // both produce identical output. The remaining tests cover diagnostics.
    use super::*;

    #[test]
    fn wire_dsl_matches_connect_for_linear_identity_map_chain() {
        let connect_graph = GraphDsl::try_create(|builder| {
            let identity = builder.add(Identity::<u64>::new());
            let plus_one = builder.add(MapStage::new(|item: u64| item + 1));
            let times_two = builder.add(MapStage::new(|item: u64| item * 2));
            builder.connect(identity.outlet(), plus_one.inlet())?;
            builder.connect(plus_one.outlet(), times_two.inlet())?;
            Ok(FlowShape::new(identity.inlet(), times_two.outlet()))
        })
        .unwrap();
        let wire_graph = GraphDsl::try_create(|builder| {
            let identity = builder.add(Identity::<u64>::new());
            let plus_one = builder.add(MapStage::new(|item: u64| item + 1));
            let times_two = builder.add(MapStage::new(|item: u64| item * 2));
            builder
                .wire(identity.to(&plus_one))
                .wire(plus_one.to(&times_two));
            Ok(FlowShape::new(identity.inlet(), times_two.outlet()))
        })
        .unwrap();
        let input = 0_u64..32;
        assert_eq!(
            wire_graph.run_with_input(input.clone()).unwrap(),
            connect_graph.run_with_input(input).unwrap()
        );
    }

    #[test]
    fn wire_dsl_matches_connect_for_broadcast_zip_diamond() {
        let connect_graph = GraphDsl::try_create(|builder| {
            let broadcast = builder.add(Broadcast::<i32>::new(2));
            let zip = builder.add(Zip::<i32, i32>::new());
            builder.connect(broadcast.outlet(0)?, zip.in0())?;
            builder.connect(broadcast.outlet(1)?, zip.in1())?;
            Ok(FlowShape::new(broadcast.inlet(), zip.outlet()))
        })
        .unwrap();
        let wire_graph = GraphDsl::try_create(|builder| {
            let broadcast = builder.add(Broadcast::<i32>::new(2));
            let zip = builder.add(Zip::<i32, i32>::new());
            // Each auto-wire advances to the next free port on both sides.
            builder.wire(broadcast.to(&zip)).wire(broadcast.to(&zip));
            Ok(FlowShape::new(broadcast.inlet(), zip.outlet()))
        })
        .unwrap();
        let input = 0..16;
        assert_eq!(
            wire_graph.run_with_input(input.clone()).unwrap(),
            connect_graph.run_with_input(input).unwrap()
        );
    }

    #[test]
    fn wire_dsl_matches_connect_for_balance_merge() {
        let connect_graph = GraphDsl::try_create(|builder| {
            let balance = builder.add(Balance::<i32>::new(2));
            let merge = builder.add(Merge::<i32>::new(2));
            builder.connect(balance.outlet(0)?, merge.inlet(0)?)?;
            builder.connect(balance.outlet(1)?, merge.inlet(1)?)?;
            Ok(FlowShape::new(balance.inlet(), merge.outlet()))
        })
        .unwrap();
        let wire_graph = GraphDsl::try_create(|builder| {
            let balance = builder.add(Balance::<i32>::new(2));
            let merge = builder.add(Merge::<i32>::new(2));
            builder.wire(balance.to(&merge)).wire(balance.to(&merge));
            Ok(FlowShape::new(balance.inlet(), merge.outlet()))
        })
        .unwrap();
        let input = 0..32;
        assert_eq!(
            wire_graph.run_with_input(input.clone()).unwrap(),
            connect_graph.run_with_input(input).unwrap()
        );
    }

    #[test]
    fn wire_dsl_matches_connect_for_partition_merge() {
        let connect_graph = GraphDsl::try_create(|builder| {
            let partition = builder.add(Partition::<i32>::new(2, |item| (*item % 2) as usize));
            let merge = builder.add(Merge::<i32>::new(2));
            builder.connect(partition.outlet(0)?, merge.inlet(0)?)?;
            builder.connect(partition.outlet(1)?, merge.inlet(1)?)?;
            Ok(FlowShape::new(partition.inlet(), merge.outlet()))
        })
        .unwrap();
        let wire_graph = GraphDsl::try_create(|builder| {
            let partition = builder.add(Partition::<i32>::new(2, |item| (*item % 2) as usize));
            let merge = builder.add(Merge::<i32>::new(2));
            builder
                .wire(partition.to(&merge))
                .wire(partition.to(&merge));
            Ok(FlowShape::new(partition.inlet(), merge.outlet()))
        })
        .unwrap();
        let input = 0..32;
        assert_eq!(
            wire_graph.run_with_input(input.clone()).unwrap(),
            connect_graph.run_with_input(input).unwrap()
        );
    }

    #[test]
    fn wire_dsl_matches_connect_for_merge_preferred_feedback_cycle() {
        let connect_graph = GraphDsl::try_create(|builder| {
            let merge = builder.add(MergePreferred::<u64>::new(1));
            let broadcast = builder.add(Broadcast::<u64>::new(2));
            let buffer = builder.add(Buffer::<u64>::new(8, OverflowStrategy::Backpressure));
            let positive = builder.add(TakeWhile::<u64>::new(|item| *item > 0));
            let decrement = builder.add(MapStage::new(|item: u64| item - 1));
            builder.connect(merge.outlet(), broadcast.inlet())?;
            builder.connect(broadcast.outlet(1)?, buffer.inlet())?;
            builder.connect(buffer.outlet(), positive.inlet())?;
            builder.connect(positive.outlet(), decrement.inlet())?;
            builder.connect(decrement.outlet(), merge.preferred())?;
            Ok(FlowShape::new(merge.secondary(0)?, broadcast.outlet(0)?))
        })
        .unwrap();
        let wire_graph = GraphDsl::try_create(|builder| {
            let merge = builder.add(MergePreferred::<u64>::new(1));
            let broadcast = builder.add(Broadcast::<u64>::new(2));
            let buffer = builder.add(Buffer::<u64>::new(8, OverflowStrategy::Backpressure));
            let positive = builder.add(TakeWhile::<u64>::new(|item| *item > 0));
            let decrement = builder.add(MapStage::new(|item: u64| item - 1));
            // MergePreferred is explicit-only: its outlet is addressed through a
            // cursor and the feedback edge targets `preferred()` directly.
            builder
                .wire(merge.out(0).to(&broadcast))
                .wire(broadcast.out(1).to(&buffer))
                .wire(buffer.to(&positive))
                .wire(positive.to(&decrement))
                .wire(decrement.to(&merge.preferred()));
            Ok(FlowShape::new(merge.secondary(0)?, broadcast.outlet(0)?))
        })
        .unwrap();
        assert_eq!(
            wire_graph.run_with_input([5]).unwrap(),
            connect_graph.run_with_input([5]).unwrap()
        );
    }

    #[test]
    fn wire_dsl_advances_auto_ports_and_supports_explicit_cursors() {
        let auto = GraphDsl::try_create(|builder| {
            let broadcast = builder.add(Broadcast::<i32>::new(2));
            let zip = builder.add(Zip::<i32, i32>::new());
            builder.wire(broadcast.to(&zip)).wire(broadcast.to(&zip));
            Ok(FlowShape::new(broadcast.inlet(), zip.outlet()))
        })
        .unwrap();
        assert_eq!(auto.edge_count(), 2);
        let explicit = GraphDsl::try_create(|builder| {
            let broadcast = builder.add(Broadcast::<i32>::new(2));
            let zip = builder.add(Zip::<i32, i32>::new());
            // Cursors can be wired in any order, independent of connection state.
            builder
                .wire(broadcast.out(1).to(&zip.in_(1)))
                .wire(broadcast.out(0).to(&zip.in_(0)));
            Ok(FlowShape::new(broadcast.inlet(), zip.outlet()))
        })
        .unwrap();
        assert_eq!(explicit.edge_count(), 2);
        let input = [1, 2, 3];
        assert_eq!(
            explicit.run_with_input(input).unwrap(),
            auto.run_with_input(input).unwrap()
        );
    }

    #[test]
    fn wire_dsl_records_type_mismatch_with_selected_port_context() {
        let error = GraphDsl::create(|builder| {
            let first = builder.add(Identity::<u64>::new());
            let second = builder.add(Identity::<String>::new());
            builder.wire(first.to(&second));
            FlowShape::new(first.inlet(), second.outlet())
        })
        .unwrap_err();
        let message = error.to_string();
        assert!(
            message.contains("Identity.out[0] -> Identity.in[0]"),
            "{message}"
        );
        assert!(
            message.contains("cannot connect outlet Identity.out"),
            "{message}"
        );
        assert!(message.contains("to inlet Identity.in"), "{message}");
    }

    #[test]
    fn wire_dsl_rejects_duplicate_ports_and_no_open_auto_port() {
        let duplicate = GraphDsl::create(|builder| {
            let first = builder.add(Identity::<i32>::new());
            let second = builder.add(Identity::<i32>::new());
            let third = builder.add(Identity::<i32>::new());
            builder
                .wire(first.out(0).to(&second))
                .wire(first.out(0).to(&third));
            FlowShape::new(first.inlet(), third.outlet())
        })
        .unwrap_err()
        .to_string();
        assert!(duplicate.contains("Identity.out[0]"), "{duplicate}");
        assert!(duplicate.contains("already connected"), "{duplicate}");
        // A third auto-wire from a two-way broadcast has no free outlet left.
        let no_open = GraphDsl::create(|builder| {
            let broadcast = builder.add(Broadcast::<i32>::new(2));
            let merge = builder.add(Merge::<i32>::new(2));
            builder
                .wire(broadcast.to(&merge))
                .wire(broadcast.to(&merge))
                .wire(broadcast.to(&merge));
            FlowShape::new(broadcast.inlet(), merge.outlet())
        })
        .unwrap_err()
        .to_string();
        assert!(
            no_open.contains("Broadcast.out[2]: no unconnected outlet"),
            "{no_open}"
        );
    }

    #[test]
    fn wire_dsl_defers_wire_errors_but_try_wire_fails_fast() {
        let deferred = GraphDsl::try_create(|builder| {
            let first = builder.add(Identity::<i32>::new());
            let second = builder.add(Identity::<u64>::new());
            builder.wire(first.to(&second));
            Ok(FlowShape::new(first.inlet(), second.outlet()))
        })
        .unwrap_err()
        .to_string();
        assert!(
            deferred.contains("Identity.out[0] -> Identity.in[0]"),
            "{deferred}"
        );
        let immediate = GraphDsl::try_create(|builder| {
            let first = builder.add(Identity::<i32>::new());
            let second = builder.add(Identity::<u64>::new());
            builder.try_wire(first.to(&second))?;
            Ok(FlowShape::new(first.inlet(), second.outlet()))
        })
        .unwrap_err()
        .to_string();
        assert!(
            immediate.contains("Identity.out[0] -> Identity.in[0]"),
            "{immediate}"
        );
        assert!(
            !immediate.contains("result shape"),
            "try_wire should return before finish aggregation: {immediate}"
        );
    }
}