Skip to main content

datum/graph/
builder.rs

1use super::*;
2use crate::Attributes;
3
4type PartialGraphBuilder<S> = dyn Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync;
5
6#[derive(Clone, Debug)]
7struct PortRecord {
8    kind: PortKind,
9    type_id: TypeId,
10    type_name: &'static str,
11    name: Arc<str>,
12}
13
14#[derive(Clone, Debug)]
15pub(super) struct Edge {
16    pub(super) outlet: PortId,
17    pub(super) inlet: PortId,
18}
19
20#[derive(Clone)]
21pub(super) struct StageRecord {
22    pub(super) spec: StageSpec,
23    pub(super) logic_factory: Option<Arc<dyn Fn() -> GraphStageLogic + Send + Sync>>,
24}
25
26impl std::fmt::Debug for StageRecord {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("StageRecord")
29            .field("spec", &self.spec)
30            .field("has_logic", &self.logic_factory.is_some())
31            .finish()
32    }
33}
34
35#[derive(Debug, Default)]
36pub struct GraphBuilder {
37    allocator: PortAllocator,
38    ports: HashMap<PortId, PortRecord>,
39    stages: Vec<StageRecord>,
40    edges: Vec<Edge>,
41    errors: Vec<StreamError>,
42}
43
44impl GraphBuilder {
45    #[must_use]
46    pub fn add<G: GraphStage>(&mut self, stage: G) -> G::Shape {
47        self.add_with_attributes(stage, Attributes::default())
48    }
49
50    #[must_use]
51    pub fn add_with_attributes<G: GraphStage>(
52        &mut self,
53        stage: G,
54        attributes: Attributes,
55    ) -> G::Shape {
56        let shape = stage.allocate_shape(&mut self.allocator);
57        let inlets = shape.inlets();
58        let outlets = shape.outlets();
59        self.ports.reserve(inlets.len() + outlets.len());
60
61        for inlet in &inlets {
62            self.register_inlet(inlet);
63        }
64        for outlet in &outlets {
65            self.register_outlet(outlet);
66        }
67        let spec = stage
68            .stage_spec_with_ports(&shape, inlets, outlets)
69            .add_attributes(attributes);
70        let logic_factory = if matches!(spec.kind, StageKind::Opaque) {
71            let shape_clone = shape.clone();
72            Some(Arc::new(move || stage.create_logic(&shape_clone))
73                as Arc<dyn Fn() -> GraphStageLogic + Send + Sync>)
74        } else {
75            None
76        };
77        self.stages.push(StageRecord {
78            spec,
79            logic_factory,
80        });
81        shape
82    }
83
84    #[must_use]
85    pub fn add_named<G: GraphStage>(&mut self, stage: G, name: impl Into<String>) -> G::Shape {
86        self.add_with_attributes(stage, Attributes::named(name))
87    }
88
89    pub fn connect<T: 'static>(&mut self, outlet: Outlet<T>, inlet: Inlet<T>) -> StreamResult<()> {
90        self.connect_any(outlet.erase(), inlet.erase())
91    }
92
93    pub fn connect_any(&mut self, outlet: AnyOutlet, inlet: AnyInlet) -> StreamResult<()> {
94        match self.validate_connection(&outlet, &inlet) {
95            Ok(()) => {
96                self.edges.push(Edge {
97                    outlet: outlet.id(),
98                    inlet: inlet.id(),
99                });
100                Ok(())
101            }
102            Err(error) => {
103                self.errors.push(error.clone());
104                Err(error)
105            }
106        }
107    }
108
109    pub fn import<S: Shape>(&mut self, graph: &PartialGraph<S>) -> StreamResult<S> {
110        graph.build(self)
111    }
112
113    fn register_inlet(&mut self, inlet: &AnyInlet) {
114        self.ports.insert(
115            inlet.id(),
116            PortRecord {
117                kind: PortKind::Inlet,
118                type_id: inlet.type_id(),
119                type_name: inlet.type_name(),
120                name: Arc::clone(&inlet.name),
121            },
122        );
123    }
124
125    fn register_outlet(&mut self, outlet: &AnyOutlet) {
126        self.ports.insert(
127            outlet.id(),
128            PortRecord {
129                kind: PortKind::Outlet,
130                type_id: outlet.type_id(),
131                type_name: outlet.type_name(),
132                name: Arc::clone(&outlet.name),
133            },
134        );
135    }
136
137    fn validate_connection(&self, outlet: &AnyOutlet, inlet: &AnyInlet) -> StreamResult<()> {
138        let outlet_record = self.ports.get(&outlet.id()).ok_or_else(|| {
139            StreamError::GraphValidation(format!("unknown outlet {}", outlet.name()))
140        })?;
141        let inlet_record = self.ports.get(&inlet.id()).ok_or_else(|| {
142            StreamError::GraphValidation(format!("unknown inlet {}", inlet.name()))
143        })?;
144
145        if outlet_record.kind != PortKind::Outlet {
146            return Err(StreamError::GraphValidation(format!(
147                "{} is not an outlet",
148                outlet_record.name
149            )));
150        }
151        if inlet_record.kind != PortKind::Inlet {
152            return Err(StreamError::GraphValidation(format!(
153                "{} is not an inlet",
154                inlet_record.name
155            )));
156        }
157        if outlet_record.type_id != inlet_record.type_id {
158            return Err(StreamError::GraphValidation(format!(
159                "cannot connect outlet {} ({}) to inlet {} ({})",
160                outlet_record.name,
161                outlet_record.type_name,
162                inlet_record.name,
163                inlet_record.type_name
164            )));
165        }
166        if self.edges.iter().any(|edge| edge.outlet == outlet.id()) {
167            return Err(StreamError::GraphValidation(format!(
168                "outlet {} is already connected",
169                outlet_record.name
170            )));
171        }
172        if self.edges.iter().any(|edge| edge.inlet == inlet.id()) {
173            return Err(StreamError::GraphValidation(format!(
174                "inlet {} is already connected",
175                inlet_record.name
176            )));
177        }
178
179        Ok(())
180    }
181
182    fn finish<S: Shape>(self, shape: S) -> StreamResult<GraphBlueprint<S>> {
183        let mut errors = self.errors;
184        let connected_inlets: HashSet<PortId> = self.edges.iter().map(|edge| edge.inlet).collect();
185        let connected_outlets: HashSet<PortId> =
186            self.edges.iter().map(|edge| edge.outlet).collect();
187
188        let open_inlets: HashSet<PortId> = self
189            .ports
190            .iter()
191            .filter_map(|(id, port)| {
192                (port.kind == PortKind::Inlet && !connected_inlets.contains(id)).then_some(*id)
193            })
194            .collect();
195        let open_outlets: HashSet<PortId> = self
196            .ports
197            .iter()
198            .filter_map(|(id, port)| {
199                (port.kind == PortKind::Outlet && !connected_outlets.contains(id)).then_some(*id)
200            })
201            .collect();
202
203        let result_inlets: HashSet<PortId> = shape.inlets().iter().map(AnyInlet::id).collect();
204        let result_outlets: HashSet<PortId> = shape.outlets().iter().map(AnyOutlet::id).collect();
205
206        for inlet in shape.inlets() {
207            match self.ports.get(&inlet.id()) {
208                Some(port)
209                    if port.kind == PortKind::Inlet
210                        && port.type_id == inlet.type_id()
211                        && port.name.as_ref() == inlet.name() => {}
212                Some(port) if port.kind == PortKind::Inlet => {
213                    errors.push(StreamError::GraphValidation(format!(
214                        "result shape inlet {} does not match registered inlet {} ({})",
215                        inlet.name(),
216                        port.name,
217                        port.type_name
218                    )));
219                }
220                Some(port) => errors.push(StreamError::GraphValidation(format!(
221                    "result shape references non-inlet port {}",
222                    port.name
223                ))),
224                None => errors.push(StreamError::GraphValidation(format!(
225                    "result shape references unknown inlet {}",
226                    inlet.name()
227                ))),
228            }
229        }
230        for outlet in shape.outlets() {
231            match self.ports.get(&outlet.id()) {
232                Some(port)
233                    if port.kind == PortKind::Outlet
234                        && port.type_id == outlet.type_id()
235                        && port.name.as_ref() == outlet.name() => {}
236                Some(port) if port.kind == PortKind::Outlet => {
237                    errors.push(StreamError::GraphValidation(format!(
238                        "result shape outlet {} does not match registered outlet {} ({})",
239                        outlet.name(),
240                        port.name,
241                        port.type_name
242                    )));
243                }
244                Some(port) => errors.push(StreamError::GraphValidation(format!(
245                    "result shape references non-outlet port {}",
246                    port.name
247                ))),
248                None => errors.push(StreamError::GraphValidation(format!(
249                    "result shape references unknown outlet {}",
250                    outlet.name()
251                ))),
252            }
253        }
254
255        if open_inlets != result_inlets {
256            errors.push(StreamError::GraphValidation(format!(
257                "result shape inlets do not match open inlets: open={:?}, result={:?}",
258                describe_ports(&self.ports, &open_inlets),
259                describe_ports(&self.ports, &result_inlets)
260            )));
261        }
262        if open_outlets != result_outlets {
263            errors.push(StreamError::GraphValidation(format!(
264                "result shape outlets do not match open outlets: open={:?}, result={:?}",
265                describe_ports(&self.ports, &open_outlets),
266                describe_ports(&self.ports, &result_outlets)
267            )));
268        }
269
270        if graph_has_cycle(&self.stages, &self.edges) {
271            errors.push(StreamError::GraphValidation(
272                "graph contains a cycle; Datum still rejects cyclic fused graphs until WP-16 adds a demand-aware graph interpreter".into(),
273            ));
274        }
275
276        if !errors.is_empty() {
277            return Err(StreamError::GraphValidation(
278                errors
279                    .into_iter()
280                    .map(|error| error.to_string())
281                    .collect::<Vec<_>>()
282                    .join("; "),
283            ));
284        }
285
286        let segments = compute_segments(&self.stages);
287        Ok(GraphBlueprint {
288            shape,
289            stages: self.stages,
290            edges: self.edges,
291            segments,
292            attributes: Attributes::default(),
293        })
294    }
295}
296
297fn describe_ports(ports: &HashMap<PortId, PortRecord>, ids: &HashSet<PortId>) -> Vec<String> {
298    let mut names = ids
299        .iter()
300        .map(|id| {
301            ports
302                .get(id)
303                .map(|port| port.name.as_ref().to_owned())
304                .unwrap_or_else(|| format!("unknown#{}", id.as_usize()))
305        })
306        .collect::<Vec<_>>();
307    names.sort();
308    names
309}
310
311/// Detects a cycle in the stage connectivity graph via Kahn's algorithm.
312///
313/// The fused executor drives elements with mutually recursive `deliver`/`emit`
314/// calls (one stack frame per outlet→inlet hop), so a cyclic blueprint would
315/// recurse without bound and overflow the stack at run time. Reject cycles at
316/// build time instead.
317fn graph_has_cycle(stages: &[StageRecord], edges: &[Edge]) -> bool {
318    let mut stage_of_inlet: HashMap<PortId, usize> = HashMap::with_capacity(stages.len());
319    let mut stage_of_outlet: HashMap<PortId, usize> = HashMap::with_capacity(stages.len());
320    for (index, stage) in stages.iter().enumerate() {
321        for inlet in &stage.spec.inlets {
322            stage_of_inlet.insert(inlet.id(), index);
323        }
324        for outlet in &stage.spec.outlets {
325            stage_of_outlet.insert(outlet.id(), index);
326        }
327    }
328
329    const NO_SUCCESSOR: usize = usize::MAX;
330    let mut first_successor = vec![NO_SUCCESSOR; stages.len()];
331    let mut successor_to = Vec::with_capacity(edges.len());
332    let mut next_successor = Vec::with_capacity(edges.len());
333    let mut indegree: Vec<usize> = vec![0; stages.len()];
334    for edge in edges {
335        if let (Some(&from), Some(&to)) = (
336            stage_of_outlet.get(&edge.outlet),
337            stage_of_inlet.get(&edge.inlet),
338        ) {
339            successor_to.push(to);
340            next_successor.push(first_successor[from]);
341            first_successor[from] = successor_to.len() - 1;
342            indegree[to] += 1;
343        }
344    }
345
346    // Kahn's algorithm; visit order is irrelevant for cycle detection, so a Vec
347    // stack is cheaper than a VecDeque queue.
348    let mut stack: Vec<usize> = (0..stages.len()).filter(|&i| indegree[i] == 0).collect();
349    let mut visited = 0_usize;
350    while let Some(stage) = stack.pop() {
351        visited += 1;
352        let mut successor = first_successor[stage];
353        while successor != NO_SUCCESSOR {
354            let next = successor_to[successor];
355            indegree[next] -= 1;
356            if indegree[next] == 0 {
357                stack.push(next);
358            }
359            successor = next_successor[successor];
360        }
361    }
362
363    visited != stages.len()
364}
365
366fn compute_segments(stages: &[StageRecord]) -> Vec<FusedSegment> {
367    let mut segments = Vec::with_capacity(1);
368    let mut current = Vec::with_capacity(stages.len());
369
370    for (index, stage) in stages.iter().enumerate() {
371        if stage.spec.async_boundary && !current.is_empty() {
372            segments.push(FusedSegment {
373                stage_indices: std::mem::take(&mut current),
374            });
375        }
376        current.push(index);
377        if stage.spec.async_boundary {
378            segments.push(FusedSegment {
379                stage_indices: std::mem::take(&mut current),
380            });
381        }
382    }
383
384    if !current.is_empty() {
385        segments.push(FusedSegment {
386            stage_indices: current,
387        });
388    }
389
390    segments
391}
392
393pub struct GraphDsl;
394
395impl GraphDsl {
396    pub fn create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
397    where
398        S: Shape,
399        F: FnOnce(&mut GraphBuilder) -> S,
400    {
401        let mut builder = GraphBuilder::default();
402        let shape = build(&mut builder);
403        builder.finish(shape)
404    }
405
406    pub fn try_create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
407    where
408        S: Shape,
409        F: FnOnce(&mut GraphBuilder) -> StreamResult<S>,
410    {
411        let mut builder = GraphBuilder::default();
412        let shape = build(&mut builder)?;
413        builder.finish(shape)
414    }
415
416    pub fn partial<S, F>(build: F) -> PartialGraph<S>
417    where
418        S: Shape,
419        F: Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync + 'static,
420    {
421        PartialGraph {
422            build: Arc::new(build),
423            attributes: Attributes::default(),
424        }
425    }
426}
427
428pub trait Graph {
429    type Shape: Shape;
430
431    fn shape(&self) -> Self::Shape;
432}
433
/// A group of stages executed fused together, identified by their indices in
/// the blueprint's stage list.
#[derive(Clone, Debug)]
pub struct FusedSegment {
    stage_indices: Vec<usize>,
}

impl FusedSegment {
    /// Indices of this segment's stages (ascending when produced by the builder).
    #[must_use]
    pub fn stage_indices(&self) -> &[usize] {
        self.stage_indices.as_slice()
    }
}
445
446pub struct GraphBlueprint<S: Shape> {
447    pub(super) shape: S,
448    pub(super) stages: Vec<StageRecord>,
449    pub(super) edges: Vec<Edge>,
450    pub(super) segments: Vec<FusedSegment>,
451    pub(super) attributes: Attributes,
452}
453
454impl<S: Shape + Clone> Clone for GraphBlueprint<S> {
455    fn clone(&self) -> Self {
456        Self {
457            shape: self.shape.clone(),
458            stages: self.stages.clone(),
459            edges: self.edges.clone(),
460            segments: self.segments.clone(),
461            attributes: self.attributes.clone(),
462        }
463    }
464}
465
466impl<S: Shape + fmt::Debug> fmt::Debug for GraphBlueprint<S> {
467    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468        f.debug_struct("GraphBlueprint")
469            .field("shape", &self.shape)
470            .field("stages", &self.stages)
471            .field("edges", &self.edges)
472            .field("segments", &self.segments)
473            .field("attributes", &self.attributes)
474            .finish()
475    }
476}
477
478impl<S: Shape> GraphBlueprint<S> {
479    #[must_use]
480    pub fn shape(&self) -> S {
481        self.shape.clone()
482    }
483
484    #[must_use]
485    pub fn stage_count(&self) -> usize {
486        self.stages.len()
487    }
488
489    #[must_use]
490    pub fn edge_count(&self) -> usize {
491        self.edges.len()
492    }
493
494    #[must_use]
495    pub fn segments(&self) -> &[FusedSegment] {
496        &self.segments
497    }
498
499    #[must_use]
500    pub fn attributes(&self) -> &Attributes {
501        &self.attributes
502    }
503
504    #[must_use]
505    pub fn with_attributes(mut self, attributes: Attributes) -> Self {
506        self.attributes = attributes;
507        self
508    }
509
510    #[must_use]
511    pub fn add_attributes(mut self, attributes: Attributes) -> Self {
512        self.attributes = self.attributes.and(attributes);
513        self
514    }
515
516    #[must_use]
517    pub fn named(self, name: impl Into<String>) -> Self {
518        self.add_attributes(Attributes::named(name))
519    }
520}
521
522impl<S: Shape> Graph for GraphBlueprint<S> {
523    type Shape = S;
524
525    fn shape(&self) -> Self::Shape {
526        self.shape()
527    }
528}
529
530#[derive(Clone)]
531pub struct PartialGraph<S: Shape> {
532    build: Arc<PartialGraphBuilder<S>>,
533    attributes: Attributes,
534}
535
536impl<S: Shape> PartialGraph<S> {
537    pub fn build(&self, builder: &mut GraphBuilder) -> StreamResult<S> {
538        (self.build)(builder)
539    }
540
541    #[must_use]
542    pub fn attributes(&self) -> &Attributes {
543        &self.attributes
544    }
545
546    #[must_use]
547    pub fn with_attributes(mut self, attributes: Attributes) -> Self {
548        self.attributes = attributes;
549        self
550    }
551
552    #[must_use]
553    pub fn add_attributes(mut self, attributes: Attributes) -> Self {
554        self.attributes = self.attributes.and(attributes);
555        self
556    }
557
558    #[must_use]
559    pub fn named(self, name: impl Into<String>) -> Self {
560        self.add_attributes(Attributes::named(name))
561    }
562}
563
564impl<S: Shape> std::fmt::Debug for PartialGraph<S> {
565    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566        f.debug_struct("PartialGraph")
567            .field("attributes", &self.attributes)
568            .finish_non_exhaustive()
569    }
570}
571
572pub type ImportedGraph<S> = PartialGraph<S>;
573
/// Settings for running a graph on the fused executor.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FusedExecutionConfig {
    /// Event limit handed to the executor. How it is enforced lives in the
    /// executor, not in this file.
    pub event_limit: usize,
}

impl Default for FusedExecutionConfig {
    fn default() -> Self {
        const DEFAULT_EVENT_LIMIT: usize = 100_000_000;
        FusedExecutionConfig {
            event_limit: DEFAULT_EVENT_LIMIT,
        }
    }
}
586
587/// Execution settings for the current graph async-boundary benchmark path.
588///
589/// This path validates a typed-linear graph and uses Ractor-backed async
590/// regions with bounded handoff queues to measure real boundary crossing cost.
591#[derive(Clone, Copy, Debug, PartialEq, Eq)]
592pub struct AsyncBoundaryExecutionConfig {
593    pub fused: FusedExecutionConfig,
594    pub buffer_size: usize,
595}
596
597impl Default for AsyncBoundaryExecutionConfig {
598    fn default() -> Self {
599        Self {
600            fused: FusedExecutionConfig::default(),
601            buffer_size: 16,
602        }
603    }
604}
605
/// Outcome of a fused run that collects the elements it produced.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FusedExecutionReport<T> {
    /// Elements produced by the run.
    pub output: Vec<T>,
    /// Event count reported by the executor.
    pub events: usize,
    /// Number of async-boundary crossings reported by the executor.
    pub async_boundary_crossings: usize,
}
612
/// Outcome of a fused run that ends in a single terminal result.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FusedTerminalReport<T> {
    /// The terminal value produced by the run.
    pub result: T,
    /// Event count reported by the executor.
    pub events: usize,
    /// Number of async-boundary crossings reported by the executor.
    pub async_boundary_crossings: usize,
}