Skip to main content

datum/graph/
builder.rs

//! Graph construction: the builder, the `GraphDsl` entry points, and the
//! immutable `GraphBlueprint` they produce.
//!
//! Building a graph is side-effect-free: it only produces a blueprint, and
//! nothing is materialized until that blueprint is run. `add` registers a
//! stage's ports; `connect`/`wire` record an edge after `validate_connection`
//! has checked port kind, element `TypeId`, and single use; and `finish`
//! verifies that the returned [`Shape`] matches the open ports. Execution
//! starts only when a `run_*` method (defined in the `executor` module) is
//! called on the resulting [`GraphBlueprint`].
10
11use super::*;
12use crate::Attributes;
13
14type PartialGraphBuilder<S> = dyn Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync;
15
16#[derive(Clone, Debug)]
17struct PortRecord {
18    kind: PortKind,
19    type_id: TypeId,
20    type_name: &'static str,
21    name: Arc<str>,
22}
23
24#[derive(Clone, Debug)]
25pub(super) struct Edge {
26    pub(super) outlet: PortId,
27    pub(super) inlet: PortId,
28}
29
30#[derive(Clone)]
31pub(super) struct StageRecord {
32    pub(super) spec: StageSpec,
33    pub(super) logic_factory: Option<Arc<dyn Fn() -> GraphStageLogic + Send + Sync>>,
34}
35
36impl std::fmt::Debug for StageRecord {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("StageRecord")
39            .field("spec", &self.spec)
40            .field("has_logic", &self.logic_factory.is_some())
41            .finish()
42    }
43}
44
45/// Mutable scratch state for assembling a graph inside a `GraphDsl` closure.
46///
47/// Tracks registered stages, their ports, the edges wired between them, and any
48/// deferred wiring errors (raised by `wire`, which collects rather than
49/// short-circuits). `finish` consumes it into a validated [`GraphBlueprint`].
50#[derive(Debug, Default)]
51pub struct GraphBuilder {
52    allocator: PortAllocator,
53    ports: HashMap<PortId, PortRecord>,
54    stages: Vec<StageRecord>,
55    edges: Vec<Edge>,
56    errors: Vec<StreamError>,
57}
58
59impl GraphBuilder {
60    #[must_use]
61    pub fn add<G: GraphStage>(&mut self, stage: G) -> G::Shape {
62        self.add_with_attributes(stage, Attributes::default())
63    }
64
65    #[must_use]
66    pub fn add_with_attributes<G: GraphStage>(
67        &mut self,
68        stage: G,
69        attributes: Attributes,
70    ) -> G::Shape {
71        let shape = stage.allocate_shape(&mut self.allocator);
72        let inlets = shape.inlets();
73        let outlets = shape.outlets();
74        self.ports.reserve(inlets.len() + outlets.len());
75
76        for inlet in &inlets {
77            self.register_inlet(inlet);
78        }
79        for outlet in &outlets {
80            self.register_outlet(outlet);
81        }
82        let spec = stage
83            .stage_spec_with_ports(&shape, inlets, outlets)
84            .add_attributes(attributes);
85        let logic_factory = if matches!(spec.kind, StageKind::Opaque) {
86            let shape_clone = shape.clone();
87            Some(Arc::new(move || stage.create_logic(&shape_clone))
88                as Arc<dyn Fn() -> GraphStageLogic + Send + Sync>)
89        } else {
90            None
91        };
92        self.stages.push(StageRecord {
93            spec,
94            logic_factory,
95        });
96        shape
97    }
98
99    #[must_use]
100    pub fn add_named<G: GraphStage>(&mut self, stage: G, name: impl Into<String>) -> G::Shape {
101        self.add_with_attributes(stage, Attributes::named(name))
102    }
103
104    pub fn connect<T: 'static>(&mut self, outlet: Outlet<T>, inlet: Inlet<T>) -> StreamResult<()> {
105        self.connect_any(outlet.erase(), inlet.erase())
106    }
107
108    pub fn connect_any(&mut self, outlet: AnyOutlet, inlet: AnyInlet) -> StreamResult<()> {
109        match self.connect_any_unrecorded(outlet, inlet) {
110            Ok(()) => Ok(()),
111            Err(error) => self.record_error(error),
112        }
113    }
114
115    pub fn import<S: Shape>(&mut self, graph: &PartialGraph<S>) -> StreamResult<S> {
116        graph.build(self)
117    }
118
119    pub fn wire<W: WireSpec>(&mut self, spec: W) -> &mut Self {
120        if let Err(error) = spec.apply(self) {
121            let _ = self.record_error(error);
122        }
123        self
124    }
125
126    pub fn try_wire<W: WireSpec>(&mut self, spec: W) -> StreamResult<&mut Self> {
127        match spec.apply(self) {
128            Ok(()) => Ok(self),
129            Err(error) => {
130                self.errors.push(error.clone());
131                Err(error)
132            }
133        }
134    }
135
136    pub(super) fn connect_any_unrecorded(
137        &mut self,
138        outlet: AnyOutlet,
139        inlet: AnyInlet,
140    ) -> StreamResult<()> {
141        self.validate_connection(&outlet, &inlet)?;
142        self.edges.push(Edge {
143            outlet: outlet.id(),
144            inlet: inlet.id(),
145        });
146        Ok(())
147    }
148
149    pub(super) fn is_outlet_connected(&self, outlet: &AnyOutlet) -> bool {
150        self.edges.iter().any(|edge| edge.outlet == outlet.id())
151    }
152
153    pub(super) fn is_inlet_connected(&self, inlet: &AnyInlet) -> bool {
154        self.edges.iter().any(|edge| edge.inlet == inlet.id())
155    }
156
157    fn record_error(&mut self, error: StreamError) -> StreamResult<()> {
158        self.errors.push(error.clone());
159        Err(error)
160    }
161
162    fn register_inlet(&mut self, inlet: &AnyInlet) {
163        self.ports.insert(
164            inlet.id(),
165            PortRecord {
166                kind: PortKind::Inlet,
167                type_id: inlet.type_id(),
168                type_name: inlet.type_name(),
169                name: Arc::clone(&inlet.name),
170            },
171        );
172    }
173
174    fn register_outlet(&mut self, outlet: &AnyOutlet) {
175        self.ports.insert(
176            outlet.id(),
177            PortRecord {
178                kind: PortKind::Outlet,
179                type_id: outlet.type_id(),
180                type_name: outlet.type_name(),
181                name: Arc::clone(&outlet.name),
182            },
183        );
184    }
185
186    fn validate_connection(&self, outlet: &AnyOutlet, inlet: &AnyInlet) -> StreamResult<()> {
187        let outlet_record = self.ports.get(&outlet.id()).ok_or_else(|| {
188            StreamError::GraphValidation(format!("unknown outlet {}", outlet.name()))
189        })?;
190        let inlet_record = self.ports.get(&inlet.id()).ok_or_else(|| {
191            StreamError::GraphValidation(format!("unknown inlet {}", inlet.name()))
192        })?;
193
194        if outlet_record.kind != PortKind::Outlet {
195            return Err(StreamError::GraphValidation(format!(
196                "{} is not an outlet",
197                outlet_record.name
198            )));
199        }
200        if inlet_record.kind != PortKind::Inlet {
201            return Err(StreamError::GraphValidation(format!(
202                "{} is not an inlet",
203                inlet_record.name
204            )));
205        }
206        if outlet_record.type_id != inlet_record.type_id {
207            return Err(StreamError::GraphValidation(format!(
208                "cannot connect outlet {} ({}) to inlet {} ({})",
209                outlet_record.name,
210                outlet_record.type_name,
211                inlet_record.name,
212                inlet_record.type_name
213            )));
214        }
215        if self.edges.iter().any(|edge| edge.outlet == outlet.id()) {
216            return Err(StreamError::GraphValidation(format!(
217                "outlet {} is already connected",
218                outlet_record.name
219            )));
220        }
221        if self.edges.iter().any(|edge| edge.inlet == inlet.id()) {
222            return Err(StreamError::GraphValidation(format!(
223                "inlet {} is already connected",
224                inlet_record.name
225            )));
226        }
227
228        Ok(())
229    }
230
231    fn finish<S: Shape>(self, shape: S) -> StreamResult<GraphBlueprint<S>> {
232        let mut errors = self.errors;
233        let connected_inlets: HashSet<PortId> = self.edges.iter().map(|edge| edge.inlet).collect();
234        let connected_outlets: HashSet<PortId> =
235            self.edges.iter().map(|edge| edge.outlet).collect();
236
237        let open_inlets: HashSet<PortId> = self
238            .ports
239            .iter()
240            .filter_map(|(id, port)| {
241                (port.kind == PortKind::Inlet && !connected_inlets.contains(id)).then_some(*id)
242            })
243            .collect();
244        let open_outlets: HashSet<PortId> = self
245            .ports
246            .iter()
247            .filter_map(|(id, port)| {
248                (port.kind == PortKind::Outlet && !connected_outlets.contains(id)).then_some(*id)
249            })
250            .collect();
251
252        let result_inlets: HashSet<PortId> = shape.inlets().iter().map(AnyInlet::id).collect();
253        let result_outlets: HashSet<PortId> = shape.outlets().iter().map(AnyOutlet::id).collect();
254
255        for inlet in shape.inlets() {
256            match self.ports.get(&inlet.id()) {
257                Some(port)
258                    if port.kind == PortKind::Inlet
259                        && port.type_id == inlet.type_id()
260                        && port.name.as_ref() == inlet.name() => {}
261                Some(port) if port.kind == PortKind::Inlet => {
262                    errors.push(StreamError::GraphValidation(format!(
263                        "result shape inlet {} does not match registered inlet {} ({})",
264                        inlet.name(),
265                        port.name,
266                        port.type_name
267                    )));
268                }
269                Some(port) => errors.push(StreamError::GraphValidation(format!(
270                    "result shape references non-inlet port {}",
271                    port.name
272                ))),
273                None => errors.push(StreamError::GraphValidation(format!(
274                    "result shape references unknown inlet {}",
275                    inlet.name()
276                ))),
277            }
278        }
279        for outlet in shape.outlets() {
280            match self.ports.get(&outlet.id()) {
281                Some(port)
282                    if port.kind == PortKind::Outlet
283                        && port.type_id == outlet.type_id()
284                        && port.name.as_ref() == outlet.name() => {}
285                Some(port) if port.kind == PortKind::Outlet => {
286                    errors.push(StreamError::GraphValidation(format!(
287                        "result shape outlet {} does not match registered outlet {} ({})",
288                        outlet.name(),
289                        port.name,
290                        port.type_name
291                    )));
292                }
293                Some(port) => errors.push(StreamError::GraphValidation(format!(
294                    "result shape references non-outlet port {}",
295                    port.name
296                ))),
297                None => errors.push(StreamError::GraphValidation(format!(
298                    "result shape references unknown outlet {}",
299                    outlet.name()
300                ))),
301            }
302        }
303
304        if open_inlets != result_inlets {
305            errors.push(StreamError::GraphValidation(format!(
306                "result shape inlets do not match open inlets: open={:?}, result={:?}",
307                describe_ports(&self.ports, &open_inlets),
308                describe_ports(&self.ports, &result_inlets)
309            )));
310        }
311        if open_outlets != result_outlets {
312            errors.push(StreamError::GraphValidation(format!(
313                "result shape outlets do not match open outlets: open={:?}, result={:?}",
314                describe_ports(&self.ports, &open_outlets),
315                describe_ports(&self.ports, &result_outlets)
316            )));
317        }
318
319        if !errors.is_empty() {
320            return Err(StreamError::GraphValidation(
321                errors
322                    .into_iter()
323                    .map(|error| error.to_string())
324                    .collect::<Vec<_>>()
325                    .join("; "),
326            ));
327        }
328
329        let segments = compute_segments(&self.stages);
330        Ok(GraphBlueprint {
331            shape,
332            stages: self.stages,
333            edges: self.edges,
334            segments,
335            attributes: Attributes::default(),
336        })
337    }
338}
339
340fn describe_ports(ports: &HashMap<PortId, PortRecord>, ids: &HashSet<PortId>) -> Vec<String> {
341    let mut names = ids
342        .iter()
343        .map(|id| {
344            ports
345                .get(id)
346                .map(|port| port.name.as_ref().to_owned())
347                .unwrap_or_else(|| format!("unknown#{}", id.as_usize()))
348        })
349        .collect::<Vec<_>>();
350    names.sort();
351    names
352}
353
354fn compute_segments(stages: &[StageRecord]) -> Vec<FusedSegment> {
355    let mut segments = Vec::with_capacity(1);
356    let mut current = Vec::with_capacity(stages.len());
357
358    for (index, stage) in stages.iter().enumerate() {
359        if stage.spec.async_boundary && !current.is_empty() {
360            segments.push(FusedSegment {
361                stage_indices: std::mem::take(&mut current),
362            });
363        }
364        current.push(index);
365        if stage.spec.async_boundary {
366            segments.push(FusedSegment {
367                stage_indices: std::mem::take(&mut current),
368            });
369        }
370    }
371
372    if !current.is_empty() {
373        segments.push(FusedSegment {
374            stage_indices: current,
375        });
376    }
377
378    segments
379}
380
381/// Entry points for building graphs. `create` takes a closure returning the
382/// shape directly, `try_create` a closure returning `StreamResult<Shape>` (so
383/// `?` works on `connect`/`try_wire`), and `partial` builds a reusable
384/// [`PartialGraph`] fragment.
385pub struct GraphDsl;
386
387impl GraphDsl {
388    pub fn create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
389    where
390        S: Shape,
391        F: FnOnce(&mut GraphBuilder) -> S,
392    {
393        let mut builder = GraphBuilder::default();
394        let shape = build(&mut builder);
395        builder.finish(shape)
396    }
397
398    pub fn try_create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
399    where
400        S: Shape,
401        F: FnOnce(&mut GraphBuilder) -> StreamResult<S>,
402    {
403        let mut builder = GraphBuilder::default();
404        let shape = build(&mut builder)?;
405        builder.finish(shape)
406    }
407
408    pub fn partial<S, F>(build: F) -> PartialGraph<S>
409    where
410        S: Shape,
411        F: Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync + 'static,
412    {
413        PartialGraph {
414            build: Arc::new(build),
415            attributes: Attributes::default(),
416        }
417    }
418}
419
420/// Anything exposing a graph [`Shape`] — implemented by [`GraphBlueprint`].
421pub trait Graph {
422    type Shape: Shape;
423
424    fn shape(&self) -> Self::Shape;
425}
426
/// A group of stages that the fused executor runs together: a maximal run of
/// stages with no async boundary between them. `AsyncBoundary` stages split a
/// graph into consecutive segments.
#[derive(Debug, Clone)]
pub struct FusedSegment {
    stage_indices: Vec<usize>,
}

impl FusedSegment {
    /// Indices of the stages belonging to this segment, in the order they
    /// were grouped.
    #[must_use]
    pub fn stage_indices(&self) -> &[usize] {
        self.stage_indices.as_slice()
    }
}
441
442/// An immutable, validated graph ready to run. Produced by `GraphDsl::create`/
443/// `try_create`. Carries the external [`Shape`], the stages, the wired edges,
444/// the precomputed fused segments, and graph-level [`Attributes`]. The `run_*`
445/// methods are defined in the `executor` module; running one never mutates the
446/// blueprint, so it can be reused and run concurrently.
447pub struct GraphBlueprint<S: Shape> {
448    pub(super) shape: S,
449    pub(super) stages: Vec<StageRecord>,
450    pub(super) edges: Vec<Edge>,
451    pub(super) segments: Vec<FusedSegment>,
452    pub(super) attributes: Attributes,
453}
454
455impl<S: Shape + Clone> Clone for GraphBlueprint<S> {
456    fn clone(&self) -> Self {
457        Self {
458            shape: self.shape.clone(),
459            stages: self.stages.clone(),
460            edges: self.edges.clone(),
461            segments: self.segments.clone(),
462            attributes: self.attributes.clone(),
463        }
464    }
465}
466
467impl<S: Shape + fmt::Debug> fmt::Debug for GraphBlueprint<S> {
468    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469        f.debug_struct("GraphBlueprint")
470            .field("shape", &self.shape)
471            .field("stages", &self.stages)
472            .field("edges", &self.edges)
473            .field("segments", &self.segments)
474            .field("attributes", &self.attributes)
475            .finish()
476    }
477}
478
479impl<S: Shape> GraphBlueprint<S> {
480    #[must_use]
481    pub fn shape(&self) -> S {
482        self.shape.clone()
483    }
484
485    #[must_use]
486    pub fn stage_count(&self) -> usize {
487        self.stages.len()
488    }
489
490    #[must_use]
491    pub fn edge_count(&self) -> usize {
492        self.edges.len()
493    }
494
495    #[must_use]
496    pub fn segments(&self) -> &[FusedSegment] {
497        &self.segments
498    }
499
500    #[must_use]
501    pub fn attributes(&self) -> &Attributes {
502        &self.attributes
503    }
504
505    #[must_use]
506    pub fn with_attributes(mut self, attributes: Attributes) -> Self {
507        self.attributes = attributes;
508        self
509    }
510
511    #[must_use]
512    pub fn add_attributes(mut self, attributes: Attributes) -> Self {
513        self.attributes = self.attributes.and(attributes);
514        self
515    }
516
517    #[must_use]
518    pub fn named(self, name: impl Into<String>) -> Self {
519        self.add_attributes(Attributes::named(name))
520    }
521}
522
523impl<S: Shape> Graph for GraphBlueprint<S> {
524    type Shape = S;
525
526    fn shape(&self) -> Self::Shape {
527        self.shape()
528    }
529}
530
531/// A reusable graph fragment: a builder closure plus its [`Shape`], importable
532/// into multiple parent graphs via [`GraphBuilder::import`]. Still a blueprint —
533/// the closure runs (allocating fresh ports) each time it is imported.
534/// Aliased as [`ImportedGraph`].
535#[derive(Clone)]
536pub struct PartialGraph<S: Shape> {
537    build: Arc<PartialGraphBuilder<S>>,
538    attributes: Attributes,
539}
540
541impl<S: Shape> PartialGraph<S> {
542    pub fn build(&self, builder: &mut GraphBuilder) -> StreamResult<S> {
543        (self.build)(builder)
544    }
545
546    #[must_use]
547    pub fn attributes(&self) -> &Attributes {
548        &self.attributes
549    }
550
551    #[must_use]
552    pub fn with_attributes(mut self, attributes: Attributes) -> Self {
553        self.attributes = attributes;
554        self
555    }
556
557    #[must_use]
558    pub fn add_attributes(mut self, attributes: Attributes) -> Self {
559        self.attributes = self.attributes.and(attributes);
560        self
561    }
562
563    #[must_use]
564    pub fn named(self, name: impl Into<String>) -> Self {
565        self.add_attributes(Attributes::named(name))
566    }
567}
568
569impl<S: Shape> std::fmt::Debug for PartialGraph<S> {
570    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
571        f.debug_struct("PartialGraph")
572            .field("attributes", &self.attributes)
573            .finish_non_exhaustive()
574    }
575}
576
577pub type ImportedGraph<S> = PartialGraph<S>;
578
/// Limits applied to a fused-executor run.
///
/// `event_limit` caps how many push/pull events a single run may take, so an
/// unproductive cycle surfaces [`StreamError::EventLimitExceeded`] instead of
/// hanging. The default is 100 million events.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusedExecutionConfig {
    /// Maximum number of events a single run may process.
    pub event_limit: usize,
}

impl Default for FusedExecutionConfig {
    fn default() -> Self {
        let event_limit = 100_000_000;
        Self { event_limit }
    }
}
594
595/// Execution settings for the current graph async-boundary benchmark path.
596///
597/// This path validates a typed-linear graph and uses Ractor-backed async
598/// regions with bounded handoff queues to measure real boundary crossing cost.
599#[derive(Clone, Copy, Debug, PartialEq, Eq)]
600pub struct AsyncBoundaryExecutionConfig {
601    pub fused: FusedExecutionConfig,
602    pub buffer_size: usize,
603}
604
605impl Default for AsyncBoundaryExecutionConfig {
606    fn default() -> Self {
607        Self {
608            fused: FusedExecutionConfig::default(),
609            buffer_size: 16,
610        }
611    }
612}
613
/// Outcome of a collecting run, returned by the `*_with_input_report`
/// methods: the collected elements plus instrumentation counters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusedExecutionReport<T> {
    /// Elements collected by the run.
    pub output: Vec<T>,
    /// Number of events processed.
    pub events: usize,
    /// Number of async-boundary crossings counted during the run.
    pub async_boundary_crossings: usize,
}
623
/// Outcome of a terminal (count/fold) run, returned by the
/// `*_count_*_report` / `*_fold_*_report` methods: the reduced value plus the
/// same instrumentation counters as [`FusedExecutionReport`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusedTerminalReport<T> {
    /// The reduced value produced by the run.
    pub result: T,
    /// Number of events processed.
    pub events: usize,
    /// Number of async-boundary crossings counted during the run.
    pub async_boundary_crossings: usize,
}