use super::*;
use crate::Attributes;
/// Type-erased build function stored by [`PartialGraph`]: given a builder, it adds stages and
/// connections and yields the partial graph's shape.
type PartialGraphBuilder<S> =
    dyn Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync;
/// Bookkeeping for one registered port, used when validating connections and result shapes.
#[derive(Debug, Clone)]
struct PortRecord {
    /// Whether this port is an inlet or an outlet.
    kind: PortKind,
    /// `TypeId` of the element type; a connection requires matching ids on both ends.
    type_id: TypeId,
    /// Human-readable element type name; only used in validation error messages.
    type_name: &'static str,
    /// Port name captured at registration; used in error messages and result-shape matching.
    name: Arc<str>,
}
/// A validated connection from an outlet to an inlet, identified by port ids.
#[derive(Debug, Clone)]
pub(super) struct Edge {
    /// Id of the outlet end of the connection.
    pub(super) outlet: PortId,
    /// Id of the inlet end of the connection.
    pub(super) inlet: PortId,
}
/// A stage added to a [`GraphBuilder`]: its spec plus an optional factory for its runtime logic.
#[derive(Clone)]
pub(super) struct StageRecord {
    /// Specification of the stage, including its ports and attributes.
    pub(super) spec: StageSpec,
    /// Creates a fresh `GraphStageLogic` for the stage. `GraphBuilder::add_with_attributes`
    /// only sets this for stages whose spec kind is `StageKind::Opaque`.
    pub(super) logic_factory: Option<Arc<dyn Fn() -> GraphStageLogic + Send + Sync>>,
}

impl std::fmt::Debug for StageRecord {
    /// Prints the spec and whether a logic factory is present; the closure itself is not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_logic = self.logic_factory.is_some();
        let mut record = f.debug_struct("StageRecord");
        record.field("spec", &self.spec);
        record.field("has_logic", &has_logic);
        record.finish()
    }
}
/// Mutable builder handed to [`GraphDsl`] closures for adding stages and wiring their ports.
///
/// A failed `connect`/`connect_any` call returns its error *and* records it, so a closure that
/// ignores a connection result still fails when the graph is finished.
#[derive(Debug, Default)]
pub struct GraphBuilder {
    allocator: PortAllocator,
    /// Every port registered by an added stage, keyed by id.
    ports: HashMap<PortId, PortRecord>,
    /// Added stages in insertion order; `FusedSegment`s refer to them by index.
    stages: Vec<StageRecord>,
    /// Accepted connections in insertion order.
    edges: Vec<Edge>,
    /// Outlet ids that already appear in `edges`. `connect_any` keeps this in lock-step with
    /// `edges` so the "already connected" check is O(1) instead of a scan over every edge.
    connected_outlets: HashSet<PortId>,
    /// Inlet ids that already appear in `edges`; maintained exactly like `connected_outlets`.
    connected_inlets: HashSet<PortId>,
    /// Validation errors recorded by failed connections; re-reported by `finish`.
    errors: Vec<StreamError>,
}

impl GraphBuilder {
    /// Adds `stage` with default attributes and returns its freshly allocated shape.
    #[must_use]
    pub fn add<G: GraphStage>(&mut self, stage: G) -> G::Shape {
        self.add_with_attributes(stage, Attributes::default())
    }

    /// Adds `stage`, registers all of its ports, and attaches `attributes` to its spec.
    ///
    /// When the resulting spec kind is `StageKind::Opaque`, a logic factory that captures the
    /// stage and a clone of its shape is stored; for any other kind no factory is stored.
    #[must_use]
    pub fn add_with_attributes<G: GraphStage>(
        &mut self,
        stage: G,
        attributes: Attributes,
    ) -> G::Shape {
        let shape = stage.allocate_shape(&mut self.allocator);
        let inlets = shape.inlets();
        let outlets = shape.outlets();
        self.ports.reserve(inlets.len() + outlets.len());
        for inlet in &inlets {
            self.register_inlet(inlet);
        }
        for outlet in &outlets {
            self.register_outlet(outlet);
        }
        let spec = stage
            .stage_spec_with_ports(&shape, inlets, outlets)
            .add_attributes(attributes);
        let logic_factory = if matches!(spec.kind, StageKind::Opaque) {
            let shape_clone = shape.clone();
            Some(Arc::new(move || stage.create_logic(&shape_clone))
                as Arc<dyn Fn() -> GraphStageLogic + Send + Sync>)
        } else {
            None
        };
        self.stages.push(StageRecord {
            spec,
            logic_factory,
        });
        shape
    }

    /// Adds `stage` with an `Attributes::named(name)` attribute.
    #[must_use]
    pub fn add_named<G: GraphStage>(&mut self, stage: G, name: impl Into<String>) -> G::Shape {
        self.add_with_attributes(stage, Attributes::named(name))
    }

    /// Connects a typed outlet to a typed inlet; see [`GraphBuilder::connect_any`].
    ///
    /// # Errors
    ///
    /// Same as [`GraphBuilder::connect_any`].
    pub fn connect<T: 'static>(&mut self, outlet: Outlet<T>, inlet: Inlet<T>) -> StreamResult<()> {
        self.connect_any(outlet.erase(), inlet.erase())
    }

    /// Connects an outlet to an inlet after validating the pair.
    ///
    /// # Errors
    ///
    /// Returns `StreamError::GraphValidation` in any of these cases:
    /// - either port is unknown to this builder;
    /// - a port has the wrong direction;
    /// - the element types differ;
    /// - either port is already connected.
    ///
    /// The same error is also recorded, so finishing the graph fails as well.
    pub fn connect_any(&mut self, outlet: AnyOutlet, inlet: AnyInlet) -> StreamResult<()> {
        if let Err(error) = self.validate_connection(&outlet, &inlet) {
            self.errors.push(error.clone());
            return Err(error);
        }
        let (outlet, inlet) = (outlet.id(), inlet.id());
        // Keep the connected-port sets in lock-step with `edges`; `finish` relies on this.
        self.connected_outlets.insert(outlet);
        self.connected_inlets.insert(inlet);
        self.edges.push(Edge { outlet, inlet });
        Ok(())
    }

    /// Runs `graph`'s build function against this builder and returns the shape it produces.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the partial graph's build function.
    pub fn import<S: Shape>(&mut self, graph: &PartialGraph<S>) -> StreamResult<S> {
        // NOTE(review): `graph.attributes()` is not applied to the imported stages here —
        // confirm that dropping them is intended.
        graph.build(self)
    }

    /// Records `inlet` in `ports`, replacing any previous record with the same id.
    fn register_inlet(&mut self, inlet: &AnyInlet) {
        self.ports.insert(
            inlet.id(),
            PortRecord {
                kind: PortKind::Inlet,
                type_id: inlet.type_id(),
                type_name: inlet.type_name(),
                name: Arc::clone(&inlet.name),
            },
        );
    }

    /// Records `outlet` in `ports`, replacing any previous record with the same id.
    fn register_outlet(&mut self, outlet: &AnyOutlet) {
        self.ports.insert(
            outlet.id(),
            PortRecord {
                kind: PortKind::Outlet,
                type_id: outlet.type_id(),
                type_name: outlet.type_name(),
                name: Arc::clone(&outlet.name),
            },
        );
    }

    /// Checks that `outlet` → `inlet` is a legal new connection.
    ///
    /// Both ports must be registered with the right direction and the same element type, and
    /// neither may already be connected.
    fn validate_connection(&self, outlet: &AnyOutlet, inlet: &AnyInlet) -> StreamResult<()> {
        let outlet_record = self.ports.get(&outlet.id()).ok_or_else(|| {
            StreamError::GraphValidation(format!("unknown outlet {}", outlet.name()))
        })?;
        let inlet_record = self.ports.get(&inlet.id()).ok_or_else(|| {
            StreamError::GraphValidation(format!("unknown inlet {}", inlet.name()))
        })?;
        if outlet_record.kind != PortKind::Outlet {
            return Err(StreamError::GraphValidation(format!(
                "{} is not an outlet",
                outlet_record.name
            )));
        }
        if inlet_record.kind != PortKind::Inlet {
            return Err(StreamError::GraphValidation(format!(
                "{} is not an inlet",
                inlet_record.name
            )));
        }
        if outlet_record.type_id != inlet_record.type_id {
            return Err(StreamError::GraphValidation(format!(
                "cannot connect outlet {} ({}) to inlet {} ({})",
                outlet_record.name,
                outlet_record.type_name,
                inlet_record.name,
                inlet_record.type_name
            )));
        }
        // O(1) lookups instead of scanning `edges` on every connection.
        if self.connected_outlets.contains(&outlet.id()) {
            return Err(StreamError::GraphValidation(format!(
                "outlet {} is already connected",
                outlet_record.name
            )));
        }
        if self.connected_inlets.contains(&inlet.id()) {
            return Err(StreamError::GraphValidation(format!(
                "inlet {} is already connected",
                inlet_record.name
            )));
        }
        Ok(())
    }

    /// Returns the ids of registered ports of the given `kind` that are not in `connected`.
    fn open_ports(
        ports: &HashMap<PortId, PortRecord>,
        kind: PortKind,
        connected: &HashSet<PortId>,
    ) -> HashSet<PortId> {
        ports
            .iter()
            .filter_map(|(&id, port)| (port.kind == kind && !connected.contains(&id)).then_some(id))
            .collect()
    }

    /// Validates the built graph against the result `shape` and turns it into a blueprint.
    ///
    /// Every check runs and all failures are collected:
    /// - errors recorded by earlier connections;
    /// - each port of `shape` must be a registered port with matching direction, type and name;
    /// - `shape` must expose exactly the unconnected inlets and outlets;
    /// - the graph must be acyclic.
    ///
    /// # Errors
    ///
    /// Returns one `StreamError::GraphValidation` whose message joins all collected failures
    /// with `"; "`.
    fn finish<S: Shape>(self, shape: S) -> StreamResult<GraphBlueprint<S>> {
        // Each accepted edge inserts one new id into each set, so the sizes must agree.
        debug_assert_eq!(self.connected_inlets.len(), self.edges.len());
        debug_assert_eq!(self.connected_outlets.len(), self.edges.len());
        let mut errors = self.errors;
        let open_inlets = Self::open_ports(&self.ports, PortKind::Inlet, &self.connected_inlets);
        let open_outlets =
            Self::open_ports(&self.ports, PortKind::Outlet, &self.connected_outlets);
        let shape_inlets = shape.inlets();
        let shape_outlets = shape.outlets();
        let result_inlets: HashSet<PortId> = shape_inlets.iter().map(AnyInlet::id).collect();
        let result_outlets: HashSet<PortId> = shape_outlets.iter().map(AnyOutlet::id).collect();
        for inlet in &shape_inlets {
            match self.ports.get(&inlet.id()) {
                Some(port)
                    if port.kind == PortKind::Inlet
                        && port.type_id == inlet.type_id()
                        && port.name.as_ref() == inlet.name() => {}
                Some(port) if port.kind == PortKind::Inlet => {
                    errors.push(StreamError::GraphValidation(format!(
                        "result shape inlet {} does not match registered inlet {} ({})",
                        inlet.name(),
                        port.name,
                        port.type_name
                    )));
                }
                Some(port) => errors.push(StreamError::GraphValidation(format!(
                    "result shape references non-inlet port {}",
                    port.name
                ))),
                None => errors.push(StreamError::GraphValidation(format!(
                    "result shape references unknown inlet {}",
                    inlet.name()
                ))),
            }
        }
        for outlet in &shape_outlets {
            match self.ports.get(&outlet.id()) {
                Some(port)
                    if port.kind == PortKind::Outlet
                        && port.type_id == outlet.type_id()
                        && port.name.as_ref() == outlet.name() => {}
                Some(port) if port.kind == PortKind::Outlet => {
                    errors.push(StreamError::GraphValidation(format!(
                        "result shape outlet {} does not match registered outlet {} ({})",
                        outlet.name(),
                        port.name,
                        port.type_name
                    )));
                }
                Some(port) => errors.push(StreamError::GraphValidation(format!(
                    "result shape references non-outlet port {}",
                    port.name
                ))),
                None => errors.push(StreamError::GraphValidation(format!(
                    "result shape references unknown outlet {}",
                    outlet.name()
                ))),
            }
        }
        if open_inlets != result_inlets {
            errors.push(StreamError::GraphValidation(format!(
                "result shape inlets do not match open inlets: open={:?}, result={:?}",
                describe_ports(&self.ports, &open_inlets),
                describe_ports(&self.ports, &result_inlets)
            )));
        }
        if open_outlets != result_outlets {
            errors.push(StreamError::GraphValidation(format!(
                "result shape outlets do not match open outlets: open={:?}, result={:?}",
                describe_ports(&self.ports, &open_outlets),
                describe_ports(&self.ports, &result_outlets)
            )));
        }
        if graph_has_cycle(&self.stages, &self.edges) {
            errors.push(StreamError::GraphValidation(
                "graph contains a cycle; Datum still rejects cyclic fused graphs until WP-16 adds a demand-aware graph interpreter".into(),
            ));
        }
        if !errors.is_empty() {
            return Err(StreamError::GraphValidation(
                errors
                    .into_iter()
                    .map(|error| error.to_string())
                    .collect::<Vec<_>>()
                    .join("; "),
            ));
        }
        let segments = compute_segments(&self.stages);
        Ok(GraphBlueprint {
            shape,
            stages: self.stages,
            edges: self.edges,
            segments,
            attributes: Attributes::default(),
        })
    }
}
/// Returns the names of the ports in `ids`, sorted, for use in error messages.
///
/// Ids with no record in `ports` are rendered as `unknown#<n>`.
fn describe_ports(ports: &HashMap<PortId, PortRecord>, ids: &HashSet<PortId>) -> Vec<String> {
    let mut described: Vec<String> = Vec::with_capacity(ids.len());
    for id in ids {
        let label = match ports.get(id) {
            Some(record) => record.name.to_string(),
            None => format!("unknown#{}", id.as_usize()),
        };
        described.push(label);
    }
    described.sort();
    described
}
/// Reports whether the stage graph formed by `edges` contains a directed cycle.
///
/// Uses Kahn's algorithm: stages with no remaining incoming edges are removed one at a time.
/// If some stage is never removed, it lies on (or downstream of) a cycle. Edges whose outlet or
/// inlet is not declared by any stage in `stages` are ignored.
fn graph_has_cycle(stages: &[StageRecord], edges: &[Edge]) -> bool {
    let stage_count = stages.len();

    // Port id -> index of the stage that declares it.
    let mut inlet_owner: HashMap<PortId, usize> = HashMap::with_capacity(stage_count);
    let mut outlet_owner: HashMap<PortId, usize> = HashMap::with_capacity(stage_count);
    for (owner, stage) in stages.iter().enumerate() {
        for port in &stage.spec.inlets {
            inlet_owner.insert(port.id(), owner);
        }
        for port in &stage.spec.outlets {
            outlet_owner.insert(port.id(), owner);
        }
    }

    // Adjacency lists plus the count of not-yet-removed predecessors of each stage.
    let mut downstream: Vec<Vec<usize>> = vec![Vec::new(); stage_count];
    let mut remaining_inputs = vec![0_usize; stage_count];
    for edge in edges {
        let upstream_stage = outlet_owner.get(&edge.outlet).copied();
        let downstream_stage = inlet_owner.get(&edge.inlet).copied();
        if let (Some(from), Some(to)) = (upstream_stage, downstream_stage) {
            downstream[from].push(to);
            remaining_inputs[to] += 1;
        }
    }

    let mut ready: Vec<usize> = remaining_inputs
        .iter()
        .enumerate()
        .filter_map(|(index, &count)| (count == 0).then_some(index))
        .collect();
    let mut removed = 0_usize;
    while let Some(current) = ready.pop() {
        removed += 1;
        for &next in &downstream[current] {
            remaining_inputs[next] -= 1;
            if remaining_inputs[next] == 0 {
                ready.push(next);
            }
        }
    }
    removed != stage_count
}
/// Splits the stages, in insertion order, into fused segments of stage indices.
///
/// A stage flagged `async_boundary` closes the segment accumulated before it and is placed
/// in a segment of its own; consecutive unflagged stages share one segment.
fn compute_segments(stages: &[StageRecord]) -> Vec<FusedSegment> {
    let mut segments = Vec::with_capacity(1);
    let mut open: Vec<usize> = Vec::with_capacity(stages.len());
    for (index, stage) in stages.iter().enumerate() {
        if !stage.spec.async_boundary {
            open.push(index);
            continue;
        }
        if !open.is_empty() {
            segments.push(FusedSegment {
                stage_indices: std::mem::take(&mut open),
            });
        }
        segments.push(FusedSegment {
            stage_indices: vec![index],
        });
    }
    if !open.is_empty() {
        segments.push(FusedSegment {
            stage_indices: open,
        });
    }
    segments
}
pub struct GraphDsl;
impl GraphDsl {
pub fn create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
where
S: Shape,
F: FnOnce(&mut GraphBuilder) -> S,
{
let mut builder = GraphBuilder::default();
let shape = build(&mut builder);
builder.finish(shape)
}
pub fn try_create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
where
S: Shape,
F: FnOnce(&mut GraphBuilder) -> StreamResult<S>,
{
let mut builder = GraphBuilder::default();
let shape = build(&mut builder)?;
builder.finish(shape)
}
pub fn partial<S, F>(build: F) -> PartialGraph<S>
where
S: Shape,
F: Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync + 'static,
{
PartialGraph {
build: Arc::new(build),
attributes: Attributes::default(),
}
}
}
/// A graph that exposes a result shape describing its open ports.
pub trait Graph {
    /// The shape type describing the graph's externally visible ports.
    type Shape: Shape;

    /// Returns the graph's shape.
    fn shape(&self) -> Self::Shape;
}
/// A run of stages, by index into the blueprint's stage list, grouped into one fused segment.
#[derive(Debug, Clone)]
pub struct FusedSegment {
    stage_indices: Vec<usize>,
}

impl FusedSegment {
    /// Indices of the stages in this segment, in insertion order.
    #[must_use]
    pub fn stage_indices(&self) -> &[usize] {
        self.stage_indices.as_slice()
    }
}
/// A validated, immutable graph description produced by [`GraphDsl`].
pub struct GraphBlueprint<S: Shape> {
    /// Result shape exposing the graph's open ports.
    pub(super) shape: S,
    /// Stages in insertion order.
    pub(super) stages: Vec<StageRecord>,
    /// Connections between stage ports.
    pub(super) edges: Vec<Edge>,
    /// Fused segments computed from the stages' async-boundary flags.
    pub(super) segments: Vec<FusedSegment>,
    /// Blueprint-level attributes.
    pub(super) attributes: Attributes,
}

impl<S: Shape + Clone> Clone for GraphBlueprint<S> {
    fn clone(&self) -> Self {
        let Self {
            shape,
            stages,
            edges,
            segments,
            attributes,
        } = self;
        Self {
            shape: shape.clone(),
            stages: stages.clone(),
            edges: edges.clone(),
            segments: segments.clone(),
            attributes: attributes.clone(),
        }
    }
}

impl<S: Shape + fmt::Debug> fmt::Debug for GraphBlueprint<S> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            shape,
            stages,
            edges,
            segments,
            attributes,
        } = self;
        f.debug_struct("GraphBlueprint")
            .field("shape", shape)
            .field("stages", stages)
            .field("edges", edges)
            .field("segments", segments)
            .field("attributes", attributes)
            .finish()
    }
}
impl<S: Shape> GraphBlueprint<S> {
    /// Returns a clone of the blueprint's result shape.
    #[must_use]
    pub fn shape(&self) -> S {
        self.shape.clone()
    }

    /// Number of stages in the blueprint.
    #[must_use]
    pub fn stage_count(&self) -> usize {
        self.stages.len()
    }

    /// Number of connections in the blueprint.
    #[must_use]
    pub fn edge_count(&self) -> usize {
        self.edges.len()
    }

    /// The fused segments computed when the blueprint was finished.
    #[must_use]
    pub fn segments(&self) -> &[FusedSegment] {
        self.segments.as_slice()
    }

    /// Blueprint-level attributes.
    #[must_use]
    pub fn attributes(&self) -> &Attributes {
        &self.attributes
    }

    /// Replaces the blueprint-level attributes with `attributes`.
    #[must_use]
    pub fn with_attributes(self, attributes: Attributes) -> Self {
        let mut blueprint = self;
        blueprint.attributes = attributes;
        blueprint
    }

    /// Combines `attributes` into the existing ones via `Attributes::and`.
    #[must_use]
    pub fn add_attributes(self, attributes: Attributes) -> Self {
        let mut blueprint = self;
        blueprint.attributes = blueprint.attributes.and(attributes);
        blueprint
    }

    /// Adds an `Attributes::named(name)` attribute.
    #[must_use]
    pub fn named(self, name: impl Into<String>) -> Self {
        let name_attribute = Attributes::named(name);
        self.add_attributes(name_attribute)
    }
}
impl<S: Shape> Graph for GraphBlueprint<S> {
    type Shape = S;

    /// Returns a clone of the result shape (same as the inherent `GraphBlueprint::shape`).
    fn shape(&self) -> Self::Shape {
        // Read the field directly so it is obvious this does not recurse into the trait method.
        self.shape.clone()
    }
}
/// A reusable sub-graph recipe that can be built into any [`GraphBuilder`], e.g. via
/// [`GraphBuilder::import`].
#[derive(Clone)]
pub struct PartialGraph<S: Shape> {
    /// Shared build function; cloning the partial graph shares it through the `Arc`.
    build: Arc<PartialGraphBuilder<S>>,
    /// Attributes attached to this partial graph.
    attributes: Attributes,
}

impl<S: Shape> PartialGraph<S> {
    /// Runs the stored build function against `builder` and returns the resulting shape.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the build function returns.
    pub fn build(&self, builder: &mut GraphBuilder) -> StreamResult<S> {
        let build_fn: &PartialGraphBuilder<S> = &self.build;
        build_fn(builder)
    }

    /// Attributes attached to this partial graph.
    #[must_use]
    pub fn attributes(&self) -> &Attributes {
        &self.attributes
    }

    /// Replaces the attributes with `attributes`.
    #[must_use]
    pub fn with_attributes(self, attributes: Attributes) -> Self {
        let mut partial = self;
        partial.attributes = attributes;
        partial
    }

    /// Combines `attributes` into the existing ones via `Attributes::and`.
    #[must_use]
    pub fn add_attributes(self, attributes: Attributes) -> Self {
        let mut partial = self;
        partial.attributes = partial.attributes.and(attributes);
        partial
    }

    /// Adds an `Attributes::named(name)` attribute.
    #[must_use]
    pub fn named(self, name: impl Into<String>) -> Self {
        let name_attribute = Attributes::named(name);
        self.add_attributes(name_attribute)
    }
}

impl<S: Shape> std::fmt::Debug for PartialGraph<S> {
    /// Prints only the attributes; the build closure is not `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut partial = f.debug_struct("PartialGraph");
        partial.field("attributes", &self.attributes);
        partial.finish_non_exhaustive()
    }
}

/// Alias of [`PartialGraph`] used when a sub-graph is imported into another builder.
pub type ImportedGraph<S> = PartialGraph<S>;
/// Configuration for running a fused graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusedExecutionConfig {
    /// Limit on the number of events for one fused run — presumably a guard against runaway
    /// graphs; TODO confirm against the executor.
    pub event_limit: usize,
}

impl Default for FusedExecutionConfig {
    /// Uses an event limit of 100 million.
    fn default() -> Self {
        const DEFAULT_EVENT_LIMIT: usize = 100_000_000;
        Self {
            event_limit: DEFAULT_EVENT_LIMIT,
        }
    }
}
/// Configuration for running a graph that contains async boundaries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AsyncBoundaryExecutionConfig {
    /// Settings applied to each fused part of the run.
    pub fused: FusedExecutionConfig,
    /// Buffer size used at async boundaries — presumably in elements; TODO confirm.
    pub buffer_size: usize,
}

impl Default for AsyncBoundaryExecutionConfig {
    /// Default fused settings with a buffer size of 16.
    fn default() -> Self {
        let fused = FusedExecutionConfig::default();
        Self {
            buffer_size: 16,
            fused,
        }
    }
}
/// Outcome of a fused run that collects output elements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusedExecutionReport<T> {
    /// Elements produced by the run.
    pub output: Vec<T>,
    /// Number of events processed — semantics defined by the executor; TODO confirm.
    pub events: usize,
    /// Number of times data crossed an async boundary — TODO confirm counting rule.
    pub async_boundary_crossings: usize,
}
/// Outcome of a fused run that ends in a terminal result value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusedTerminalReport<T> {
    /// The terminal value produced by the run.
    pub result: T,
    /// Number of events processed — semantics defined by the executor; TODO confirm.
    pub events: usize,
    /// Number of times data crossed an async boundary — TODO confirm counting rule.
    pub async_boundary_crossings: usize,
}