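//! `datum/graph/builder.rs` — graph construction: [`GraphBuilder`] collects stages and
//! connections, and [`GraphDsl`] validates them into a [`GraphBlueprint`].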

1use super::*;
2use crate::Attributes;
3
4type PartialGraphBuilder<S> = dyn Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync;
5
6#[derive(Clone, Debug)]
7struct PortRecord {
8    kind: PortKind,
9    type_id: TypeId,
10    type_name: &'static str,
11    name: Arc<str>,
12}
13
14#[derive(Clone, Debug)]
15pub(super) struct Edge {
16    pub(super) outlet: PortId,
17    pub(super) inlet: PortId,
18}
19
20#[derive(Clone)]
21pub(super) struct StageRecord {
22    pub(super) spec: StageSpec,
23    pub(super) logic_factory: Option<Arc<dyn Fn() -> GraphStageLogic + Send + Sync>>,
24}
25
26impl std::fmt::Debug for StageRecord {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("StageRecord")
29            .field("spec", &self.spec)
30            .field("has_logic", &self.logic_factory.is_some())
31            .finish()
32    }
33}
34
35#[derive(Debug, Default)]
36pub struct GraphBuilder {
37    allocator: PortAllocator,
38    ports: HashMap<PortId, PortRecord>,
39    stages: Vec<StageRecord>,
40    edges: Vec<Edge>,
41    errors: Vec<StreamError>,
42}
43
44impl GraphBuilder {
45    #[must_use]
46    pub fn add<G: GraphStage>(&mut self, stage: G) -> G::Shape {
47        self.add_with_attributes(stage, Attributes::default())
48    }
49
50    #[must_use]
51    pub fn add_with_attributes<G: GraphStage>(
52        &mut self,
53        stage: G,
54        attributes: Attributes,
55    ) -> G::Shape {
56        let shape = stage.allocate_shape(&mut self.allocator);
57        let inlets = shape.inlets();
58        let outlets = shape.outlets();
59        self.ports.reserve(inlets.len() + outlets.len());
60
61        for inlet in &inlets {
62            self.register_inlet(inlet);
63        }
64        for outlet in &outlets {
65            self.register_outlet(outlet);
66        }
67        let spec = stage
68            .stage_spec_with_ports(&shape, inlets, outlets)
69            .add_attributes(attributes);
70        let logic_factory = if matches!(spec.kind, StageKind::Opaque) {
71            let shape_clone = shape.clone();
72            Some(Arc::new(move || stage.create_logic(&shape_clone))
73                as Arc<dyn Fn() -> GraphStageLogic + Send + Sync>)
74        } else {
75            None
76        };
77        self.stages.push(StageRecord {
78            spec,
79            logic_factory,
80        });
81        shape
82    }
83
84    #[must_use]
85    pub fn add_named<G: GraphStage>(&mut self, stage: G, name: impl Into<String>) -> G::Shape {
86        self.add_with_attributes(stage, Attributes::named(name))
87    }
88
89    pub fn connect<T: 'static>(&mut self, outlet: Outlet<T>, inlet: Inlet<T>) -> StreamResult<()> {
90        self.connect_any(outlet.erase(), inlet.erase())
91    }
92
93    pub fn connect_any(&mut self, outlet: AnyOutlet, inlet: AnyInlet) -> StreamResult<()> {
94        match self.validate_connection(&outlet, &inlet) {
95            Ok(()) => {
96                self.edges.push(Edge {
97                    outlet: outlet.id(),
98                    inlet: inlet.id(),
99                });
100                Ok(())
101            }
102            Err(error) => {
103                self.errors.push(error.clone());
104                Err(error)
105            }
106        }
107    }
108
109    pub fn import<S: Shape>(&mut self, graph: &PartialGraph<S>) -> StreamResult<S> {
110        graph.build(self)
111    }
112
113    fn register_inlet(&mut self, inlet: &AnyInlet) {
114        self.ports.insert(
115            inlet.id(),
116            PortRecord {
117                kind: PortKind::Inlet,
118                type_id: inlet.type_id(),
119                type_name: inlet.type_name(),
120                name: Arc::clone(&inlet.name),
121            },
122        );
123    }
124
125    fn register_outlet(&mut self, outlet: &AnyOutlet) {
126        self.ports.insert(
127            outlet.id(),
128            PortRecord {
129                kind: PortKind::Outlet,
130                type_id: outlet.type_id(),
131                type_name: outlet.type_name(),
132                name: Arc::clone(&outlet.name),
133            },
134        );
135    }
136
137    fn validate_connection(&self, outlet: &AnyOutlet, inlet: &AnyInlet) -> StreamResult<()> {
138        let outlet_record = self.ports.get(&outlet.id()).ok_or_else(|| {
139            StreamError::GraphValidation(format!("unknown outlet {}", outlet.name()))
140        })?;
141        let inlet_record = self.ports.get(&inlet.id()).ok_or_else(|| {
142            StreamError::GraphValidation(format!("unknown inlet {}", inlet.name()))
143        })?;
144
145        if outlet_record.kind != PortKind::Outlet {
146            return Err(StreamError::GraphValidation(format!(
147                "{} is not an outlet",
148                outlet_record.name
149            )));
150        }
151        if inlet_record.kind != PortKind::Inlet {
152            return Err(StreamError::GraphValidation(format!(
153                "{} is not an inlet",
154                inlet_record.name
155            )));
156        }
157        if outlet_record.type_id != inlet_record.type_id {
158            return Err(StreamError::GraphValidation(format!(
159                "cannot connect outlet {} ({}) to inlet {} ({})",
160                outlet_record.name,
161                outlet_record.type_name,
162                inlet_record.name,
163                inlet_record.type_name
164            )));
165        }
166        if self.edges.iter().any(|edge| edge.outlet == outlet.id()) {
167            return Err(StreamError::GraphValidation(format!(
168                "outlet {} is already connected",
169                outlet_record.name
170            )));
171        }
172        if self.edges.iter().any(|edge| edge.inlet == inlet.id()) {
173            return Err(StreamError::GraphValidation(format!(
174                "inlet {} is already connected",
175                inlet_record.name
176            )));
177        }
178
179        Ok(())
180    }
181
182    fn finish<S: Shape>(self, shape: S) -> StreamResult<GraphBlueprint<S>> {
183        let mut errors = self.errors;
184        let connected_inlets: HashSet<PortId> = self.edges.iter().map(|edge| edge.inlet).collect();
185        let connected_outlets: HashSet<PortId> =
186            self.edges.iter().map(|edge| edge.outlet).collect();
187
188        let open_inlets: HashSet<PortId> = self
189            .ports
190            .iter()
191            .filter_map(|(id, port)| {
192                (port.kind == PortKind::Inlet && !connected_inlets.contains(id)).then_some(*id)
193            })
194            .collect();
195        let open_outlets: HashSet<PortId> = self
196            .ports
197            .iter()
198            .filter_map(|(id, port)| {
199                (port.kind == PortKind::Outlet && !connected_outlets.contains(id)).then_some(*id)
200            })
201            .collect();
202
203        let result_inlets: HashSet<PortId> = shape.inlets().iter().map(AnyInlet::id).collect();
204        let result_outlets: HashSet<PortId> = shape.outlets().iter().map(AnyOutlet::id).collect();
205
206        for inlet in shape.inlets() {
207            match self.ports.get(&inlet.id()) {
208                Some(port)
209                    if port.kind == PortKind::Inlet
210                        && port.type_id == inlet.type_id()
211                        && port.name.as_ref() == inlet.name() => {}
212                Some(port) if port.kind == PortKind::Inlet => {
213                    errors.push(StreamError::GraphValidation(format!(
214                        "result shape inlet {} does not match registered inlet {} ({})",
215                        inlet.name(),
216                        port.name,
217                        port.type_name
218                    )));
219                }
220                Some(port) => errors.push(StreamError::GraphValidation(format!(
221                    "result shape references non-inlet port {}",
222                    port.name
223                ))),
224                None => errors.push(StreamError::GraphValidation(format!(
225                    "result shape references unknown inlet {}",
226                    inlet.name()
227                ))),
228            }
229        }
230        for outlet in shape.outlets() {
231            match self.ports.get(&outlet.id()) {
232                Some(port)
233                    if port.kind == PortKind::Outlet
234                        && port.type_id == outlet.type_id()
235                        && port.name.as_ref() == outlet.name() => {}
236                Some(port) if port.kind == PortKind::Outlet => {
237                    errors.push(StreamError::GraphValidation(format!(
238                        "result shape outlet {} does not match registered outlet {} ({})",
239                        outlet.name(),
240                        port.name,
241                        port.type_name
242                    )));
243                }
244                Some(port) => errors.push(StreamError::GraphValidation(format!(
245                    "result shape references non-outlet port {}",
246                    port.name
247                ))),
248                None => errors.push(StreamError::GraphValidation(format!(
249                    "result shape references unknown outlet {}",
250                    outlet.name()
251                ))),
252            }
253        }
254
255        if open_inlets != result_inlets {
256            errors.push(StreamError::GraphValidation(format!(
257                "result shape inlets do not match open inlets: open={:?}, result={:?}",
258                describe_ports(&self.ports, &open_inlets),
259                describe_ports(&self.ports, &result_inlets)
260            )));
261        }
262        if open_outlets != result_outlets {
263            errors.push(StreamError::GraphValidation(format!(
264                "result shape outlets do not match open outlets: open={:?}, result={:?}",
265                describe_ports(&self.ports, &open_outlets),
266                describe_ports(&self.ports, &result_outlets)
267            )));
268        }
269
270        if !errors.is_empty() {
271            return Err(StreamError::GraphValidation(
272                errors
273                    .into_iter()
274                    .map(|error| error.to_string())
275                    .collect::<Vec<_>>()
276                    .join("; "),
277            ));
278        }
279
280        let segments = compute_segments(&self.stages);
281        Ok(GraphBlueprint {
282            shape,
283            stages: self.stages,
284            edges: self.edges,
285            segments,
286            attributes: Attributes::default(),
287        })
288    }
289}
290
291fn describe_ports(ports: &HashMap<PortId, PortRecord>, ids: &HashSet<PortId>) -> Vec<String> {
292    let mut names = ids
293        .iter()
294        .map(|id| {
295            ports
296                .get(id)
297                .map(|port| port.name.as_ref().to_owned())
298                .unwrap_or_else(|| format!("unknown#{}", id.as_usize()))
299        })
300        .collect::<Vec<_>>();
301    names.sort();
302    names
303}
304
305fn compute_segments(stages: &[StageRecord]) -> Vec<FusedSegment> {
306    let mut segments = Vec::with_capacity(1);
307    let mut current = Vec::with_capacity(stages.len());
308
309    for (index, stage) in stages.iter().enumerate() {
310        if stage.spec.async_boundary && !current.is_empty() {
311            segments.push(FusedSegment {
312                stage_indices: std::mem::take(&mut current),
313            });
314        }
315        current.push(index);
316        if stage.spec.async_boundary {
317            segments.push(FusedSegment {
318                stage_indices: std::mem::take(&mut current),
319            });
320        }
321    }
322
323    if !current.is_empty() {
324        segments.push(FusedSegment {
325            stage_indices: current,
326        });
327    }
328
329    segments
330}
331
332pub struct GraphDsl;
333
334impl GraphDsl {
335    pub fn create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
336    where
337        S: Shape,
338        F: FnOnce(&mut GraphBuilder) -> S,
339    {
340        let mut builder = GraphBuilder::default();
341        let shape = build(&mut builder);
342        builder.finish(shape)
343    }
344
345    pub fn try_create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
346    where
347        S: Shape,
348        F: FnOnce(&mut GraphBuilder) -> StreamResult<S>,
349    {
350        let mut builder = GraphBuilder::default();
351        let shape = build(&mut builder)?;
352        builder.finish(shape)
353    }
354
355    pub fn partial<S, F>(build: F) -> PartialGraph<S>
356    where
357        S: Shape,
358        F: Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync + 'static,
359    {
360        PartialGraph {
361            build: Arc::new(build),
362            attributes: Attributes::default(),
363        }
364    }
365}
366
367pub trait Graph {
368    type Shape: Shape;
369
370    fn shape(&self) -> Self::Shape;
371}
372
/// A run of stages, identified by their indices in a blueprint's stage list, that
/// are grouped into one fused segment.
#[derive(Debug, Clone)]
pub struct FusedSegment {
    stage_indices: Vec<usize>,
}

impl FusedSegment {
    /// Indices of the stages belonging to this segment.
    #[must_use]
    pub fn stage_indices(&self) -> &[usize] {
        self.stage_indices.as_slice()
    }
}
384
385pub struct GraphBlueprint<S: Shape> {
386    pub(super) shape: S,
387    pub(super) stages: Vec<StageRecord>,
388    pub(super) edges: Vec<Edge>,
389    pub(super) segments: Vec<FusedSegment>,
390    pub(super) attributes: Attributes,
391}
392
393impl<S: Shape + Clone> Clone for GraphBlueprint<S> {
394    fn clone(&self) -> Self {
395        Self {
396            shape: self.shape.clone(),
397            stages: self.stages.clone(),
398            edges: self.edges.clone(),
399            segments: self.segments.clone(),
400            attributes: self.attributes.clone(),
401        }
402    }
403}
404
405impl<S: Shape + fmt::Debug> fmt::Debug for GraphBlueprint<S> {
406    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407        f.debug_struct("GraphBlueprint")
408            .field("shape", &self.shape)
409            .field("stages", &self.stages)
410            .field("edges", &self.edges)
411            .field("segments", &self.segments)
412            .field("attributes", &self.attributes)
413            .finish()
414    }
415}
416
417impl<S: Shape> GraphBlueprint<S> {
418    #[must_use]
419    pub fn shape(&self) -> S {
420        self.shape.clone()
421    }
422
423    #[must_use]
424    pub fn stage_count(&self) -> usize {
425        self.stages.len()
426    }
427
428    #[must_use]
429    pub fn edge_count(&self) -> usize {
430        self.edges.len()
431    }
432
433    #[must_use]
434    pub fn segments(&self) -> &[FusedSegment] {
435        &self.segments
436    }
437
438    #[must_use]
439    pub fn attributes(&self) -> &Attributes {
440        &self.attributes
441    }
442
443    #[must_use]
444    pub fn with_attributes(mut self, attributes: Attributes) -> Self {
445        self.attributes = attributes;
446        self
447    }
448
449    #[must_use]
450    pub fn add_attributes(mut self, attributes: Attributes) -> Self {
451        self.attributes = self.attributes.and(attributes);
452        self
453    }
454
455    #[must_use]
456    pub fn named(self, name: impl Into<String>) -> Self {
457        self.add_attributes(Attributes::named(name))
458    }
459}
460
461impl<S: Shape> Graph for GraphBlueprint<S> {
462    type Shape = S;
463
464    fn shape(&self) -> Self::Shape {
465        self.shape()
466    }
467}
468
469#[derive(Clone)]
470pub struct PartialGraph<S: Shape> {
471    build: Arc<PartialGraphBuilder<S>>,
472    attributes: Attributes,
473}
474
475impl<S: Shape> PartialGraph<S> {
476    pub fn build(&self, builder: &mut GraphBuilder) -> StreamResult<S> {
477        (self.build)(builder)
478    }
479
480    #[must_use]
481    pub fn attributes(&self) -> &Attributes {
482        &self.attributes
483    }
484
485    #[must_use]
486    pub fn with_attributes(mut self, attributes: Attributes) -> Self {
487        self.attributes = attributes;
488        self
489    }
490
491    #[must_use]
492    pub fn add_attributes(mut self, attributes: Attributes) -> Self {
493        self.attributes = self.attributes.and(attributes);
494        self
495    }
496
497    #[must_use]
498    pub fn named(self, name: impl Into<String>) -> Self {
499        self.add_attributes(Attributes::named(name))
500    }
501}
502
503impl<S: Shape> std::fmt::Debug for PartialGraph<S> {
504    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505        f.debug_struct("PartialGraph")
506            .field("attributes", &self.attributes)
507            .finish_non_exhaustive()
508    }
509}
510
511pub type ImportedGraph<S> = PartialGraph<S>;
512
/// Execution settings for the fused graph execution path.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FusedExecutionConfig {
    /// Event budget for a run. The default is `100_000_000`; how it is enforced is
    /// up to the executor.
    pub event_limit: usize,
}

impl Default for FusedExecutionConfig {
    fn default() -> Self {
        let event_limit = 100_000_000;
        Self { event_limit }
    }
}

/// Execution settings for the current graph async-boundary benchmark path.
///
/// This path validates a typed-linear graph and uses Ractor-backed async
/// regions with bounded handoff queues to measure the real cost of crossing
/// a boundary.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct AsyncBoundaryExecutionConfig {
    /// Settings shared with the fused execution path.
    pub fused: FusedExecutionConfig,
    /// Presumably the capacity of each bounded handoff queue (default `16`).
    pub buffer_size: usize,
}

impl Default for AsyncBoundaryExecutionConfig {
    fn default() -> Self {
        let fused = FusedExecutionConfig::default();
        Self {
            fused,
            buffer_size: 16,
        }
    }
}
544
/// Outcome of an execution run that collects output elements.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct FusedExecutionReport<T> {
    /// Elements collected from the graph.
    pub output: Vec<T>,
    /// Event count reported for the run.
    pub events: usize,
    /// Number of async-boundary crossings reported for the run.
    pub async_boundary_crossings: usize,
}
551
/// Outcome of an execution run that ends in a single terminal result.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct FusedTerminalReport<T> {
    /// Terminal value produced by the run.
    pub result: T,
    /// Event count reported for the run.
    pub events: usize,
    /// Number of async-boundary crossings reported for the run.
    pub async_boundary_crossings: usize,
}