1use super::*;
2use crate::Attributes;
3
/// Type-erased build closure held by `PartialGraph`: it receives a mutable
/// `GraphBuilder` and returns either the resulting shape `S` or a stream error.
/// The `Send + Sync` bounds are part of the trait object type itself.
type PartialGraphBuilder<S> = dyn Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync;
5
/// Bookkeeping entry kept for every port registered with a `GraphBuilder`.
#[derive(Clone, Debug)]
struct PortRecord {
    /// Whether the port was registered as an inlet or as an outlet.
    kind: PortKind,
    /// `TypeId` reported by the port handle; compared when validating connections.
    type_id: TypeId,
    /// Type name reported by the port handle; only used in error messages.
    type_name: &'static str,
    /// Port name, shared with the port handle and used in diagnostics.
    name: Arc<str>,
}
13
/// A single connection between two ports, identified by their port ids.
#[derive(Clone, Debug)]
pub(super) struct Edge {
    /// Id of the outlet the connection starts from.
    pub(super) outlet: PortId,
    /// Id of the inlet the connection ends at.
    pub(super) inlet: PortId,
}
19
/// A stage that has been added to a graph: its spec plus an optional logic factory.
#[derive(Clone)]
pub(super) struct StageRecord {
    /// Specification of the stage, including its inlets and outlets.
    pub(super) spec: StageSpec,
    /// Factory producing a `GraphStageLogic` on each call. `GraphBuilder` only
    /// fills this in for stages whose spec kind is `StageKind::Opaque`.
    pub(super) logic_factory: Option<Arc<dyn Fn() -> GraphStageLogic + Send + Sync>>,
}
25
26impl std::fmt::Debug for StageRecord {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("StageRecord")
29 .field("spec", &self.spec)
30 .field("has_logic", &self.logic_factory.is_some())
31 .finish()
32 }
33}
34
/// Mutable accumulator used while a graph is being described: stages are
/// added, ports are registered, and connections are recorded and validated.
#[derive(Debug, Default)]
pub struct GraphBuilder {
    /// Allocator handed to each stage when it allocates its shape.
    allocator: PortAllocator,
    /// Every registered port, keyed by its id.
    ports: HashMap<PortId, PortRecord>,
    /// Stages in the order they were added.
    stages: Vec<StageRecord>,
    /// Connections that passed validation.
    edges: Vec<Edge>,
    /// Connection errors recorded so far; reported when the graph is finished.
    errors: Vec<StreamError>,
}
43
44impl GraphBuilder {
45 #[must_use]
46 pub fn add<G: GraphStage>(&mut self, stage: G) -> G::Shape {
47 self.add_with_attributes(stage, Attributes::default())
48 }
49
50 #[must_use]
51 pub fn add_with_attributes<G: GraphStage>(
52 &mut self,
53 stage: G,
54 attributes: Attributes,
55 ) -> G::Shape {
56 let shape = stage.allocate_shape(&mut self.allocator);
57 let inlets = shape.inlets();
58 let outlets = shape.outlets();
59 self.ports.reserve(inlets.len() + outlets.len());
60
61 for inlet in &inlets {
62 self.register_inlet(inlet);
63 }
64 for outlet in &outlets {
65 self.register_outlet(outlet);
66 }
67 let spec = stage
68 .stage_spec_with_ports(&shape, inlets, outlets)
69 .add_attributes(attributes);
70 let logic_factory = if matches!(spec.kind, StageKind::Opaque) {
71 let shape_clone = shape.clone();
72 Some(Arc::new(move || stage.create_logic(&shape_clone))
73 as Arc<dyn Fn() -> GraphStageLogic + Send + Sync>)
74 } else {
75 None
76 };
77 self.stages.push(StageRecord {
78 spec,
79 logic_factory,
80 });
81 shape
82 }
83
84 #[must_use]
85 pub fn add_named<G: GraphStage>(&mut self, stage: G, name: impl Into<String>) -> G::Shape {
86 self.add_with_attributes(stage, Attributes::named(name))
87 }
88
89 pub fn connect<T: 'static>(&mut self, outlet: Outlet<T>, inlet: Inlet<T>) -> StreamResult<()> {
90 self.connect_any(outlet.erase(), inlet.erase())
91 }
92
93 pub fn connect_any(&mut self, outlet: AnyOutlet, inlet: AnyInlet) -> StreamResult<()> {
94 match self.validate_connection(&outlet, &inlet) {
95 Ok(()) => {
96 self.edges.push(Edge {
97 outlet: outlet.id(),
98 inlet: inlet.id(),
99 });
100 Ok(())
101 }
102 Err(error) => {
103 self.errors.push(error.clone());
104 Err(error)
105 }
106 }
107 }
108
109 pub fn import<S: Shape>(&mut self, graph: &PartialGraph<S>) -> StreamResult<S> {
110 graph.build(self)
111 }
112
113 fn register_inlet(&mut self, inlet: &AnyInlet) {
114 self.ports.insert(
115 inlet.id(),
116 PortRecord {
117 kind: PortKind::Inlet,
118 type_id: inlet.type_id(),
119 type_name: inlet.type_name(),
120 name: Arc::clone(&inlet.name),
121 },
122 );
123 }
124
125 fn register_outlet(&mut self, outlet: &AnyOutlet) {
126 self.ports.insert(
127 outlet.id(),
128 PortRecord {
129 kind: PortKind::Outlet,
130 type_id: outlet.type_id(),
131 type_name: outlet.type_name(),
132 name: Arc::clone(&outlet.name),
133 },
134 );
135 }
136
137 fn validate_connection(&self, outlet: &AnyOutlet, inlet: &AnyInlet) -> StreamResult<()> {
138 let outlet_record = self.ports.get(&outlet.id()).ok_or_else(|| {
139 StreamError::GraphValidation(format!("unknown outlet {}", outlet.name()))
140 })?;
141 let inlet_record = self.ports.get(&inlet.id()).ok_or_else(|| {
142 StreamError::GraphValidation(format!("unknown inlet {}", inlet.name()))
143 })?;
144
145 if outlet_record.kind != PortKind::Outlet {
146 return Err(StreamError::GraphValidation(format!(
147 "{} is not an outlet",
148 outlet_record.name
149 )));
150 }
151 if inlet_record.kind != PortKind::Inlet {
152 return Err(StreamError::GraphValidation(format!(
153 "{} is not an inlet",
154 inlet_record.name
155 )));
156 }
157 if outlet_record.type_id != inlet_record.type_id {
158 return Err(StreamError::GraphValidation(format!(
159 "cannot connect outlet {} ({}) to inlet {} ({})",
160 outlet_record.name,
161 outlet_record.type_name,
162 inlet_record.name,
163 inlet_record.type_name
164 )));
165 }
166 if self.edges.iter().any(|edge| edge.outlet == outlet.id()) {
167 return Err(StreamError::GraphValidation(format!(
168 "outlet {} is already connected",
169 outlet_record.name
170 )));
171 }
172 if self.edges.iter().any(|edge| edge.inlet == inlet.id()) {
173 return Err(StreamError::GraphValidation(format!(
174 "inlet {} is already connected",
175 inlet_record.name
176 )));
177 }
178
179 Ok(())
180 }
181
182 fn finish<S: Shape>(self, shape: S) -> StreamResult<GraphBlueprint<S>> {
183 let mut errors = self.errors;
184 let connected_inlets: HashSet<PortId> = self.edges.iter().map(|edge| edge.inlet).collect();
185 let connected_outlets: HashSet<PortId> =
186 self.edges.iter().map(|edge| edge.outlet).collect();
187
188 let open_inlets: HashSet<PortId> = self
189 .ports
190 .iter()
191 .filter_map(|(id, port)| {
192 (port.kind == PortKind::Inlet && !connected_inlets.contains(id)).then_some(*id)
193 })
194 .collect();
195 let open_outlets: HashSet<PortId> = self
196 .ports
197 .iter()
198 .filter_map(|(id, port)| {
199 (port.kind == PortKind::Outlet && !connected_outlets.contains(id)).then_some(*id)
200 })
201 .collect();
202
203 let result_inlets: HashSet<PortId> = shape.inlets().iter().map(AnyInlet::id).collect();
204 let result_outlets: HashSet<PortId> = shape.outlets().iter().map(AnyOutlet::id).collect();
205
206 for inlet in shape.inlets() {
207 match self.ports.get(&inlet.id()) {
208 Some(port)
209 if port.kind == PortKind::Inlet
210 && port.type_id == inlet.type_id()
211 && port.name.as_ref() == inlet.name() => {}
212 Some(port) if port.kind == PortKind::Inlet => {
213 errors.push(StreamError::GraphValidation(format!(
214 "result shape inlet {} does not match registered inlet {} ({})",
215 inlet.name(),
216 port.name,
217 port.type_name
218 )));
219 }
220 Some(port) => errors.push(StreamError::GraphValidation(format!(
221 "result shape references non-inlet port {}",
222 port.name
223 ))),
224 None => errors.push(StreamError::GraphValidation(format!(
225 "result shape references unknown inlet {}",
226 inlet.name()
227 ))),
228 }
229 }
230 for outlet in shape.outlets() {
231 match self.ports.get(&outlet.id()) {
232 Some(port)
233 if port.kind == PortKind::Outlet
234 && port.type_id == outlet.type_id()
235 && port.name.as_ref() == outlet.name() => {}
236 Some(port) if port.kind == PortKind::Outlet => {
237 errors.push(StreamError::GraphValidation(format!(
238 "result shape outlet {} does not match registered outlet {} ({})",
239 outlet.name(),
240 port.name,
241 port.type_name
242 )));
243 }
244 Some(port) => errors.push(StreamError::GraphValidation(format!(
245 "result shape references non-outlet port {}",
246 port.name
247 ))),
248 None => errors.push(StreamError::GraphValidation(format!(
249 "result shape references unknown outlet {}",
250 outlet.name()
251 ))),
252 }
253 }
254
255 if open_inlets != result_inlets {
256 errors.push(StreamError::GraphValidation(format!(
257 "result shape inlets do not match open inlets: open={:?}, result={:?}",
258 describe_ports(&self.ports, &open_inlets),
259 describe_ports(&self.ports, &result_inlets)
260 )));
261 }
262 if open_outlets != result_outlets {
263 errors.push(StreamError::GraphValidation(format!(
264 "result shape outlets do not match open outlets: open={:?}, result={:?}",
265 describe_ports(&self.ports, &open_outlets),
266 describe_ports(&self.ports, &result_outlets)
267 )));
268 }
269
270 if graph_has_cycle(&self.stages, &self.edges) {
271 errors.push(StreamError::GraphValidation(
272 "graph contains a cycle; Datum still rejects cyclic fused graphs until WP-16 adds a demand-aware graph interpreter".into(),
273 ));
274 }
275
276 if !errors.is_empty() {
277 return Err(StreamError::GraphValidation(
278 errors
279 .into_iter()
280 .map(|error| error.to_string())
281 .collect::<Vec<_>>()
282 .join("; "),
283 ));
284 }
285
286 let segments = compute_segments(&self.stages);
287 Ok(GraphBlueprint {
288 shape,
289 stages: self.stages,
290 edges: self.edges,
291 segments,
292 attributes: Attributes::default(),
293 })
294 }
295}
296
297fn describe_ports(ports: &HashMap<PortId, PortRecord>, ids: &HashSet<PortId>) -> Vec<String> {
298 let mut names = ids
299 .iter()
300 .map(|id| {
301 ports
302 .get(id)
303 .map(|port| port.name.as_ref().to_owned())
304 .unwrap_or_else(|| format!("unknown#{}", id.as_usize()))
305 })
306 .collect::<Vec<_>>();
307 names.sort();
308 names
309}
310
311fn graph_has_cycle(stages: &[StageRecord], edges: &[Edge]) -> bool {
318 let mut stage_of_inlet: HashMap<PortId, usize> = HashMap::with_capacity(stages.len());
319 let mut stage_of_outlet: HashMap<PortId, usize> = HashMap::with_capacity(stages.len());
320 for (index, stage) in stages.iter().enumerate() {
321 for inlet in &stage.spec.inlets {
322 stage_of_inlet.insert(inlet.id(), index);
323 }
324 for outlet in &stage.spec.outlets {
325 stage_of_outlet.insert(outlet.id(), index);
326 }
327 }
328
329 const NO_SUCCESSOR: usize = usize::MAX;
330 let mut first_successor = vec![NO_SUCCESSOR; stages.len()];
331 let mut successor_to = Vec::with_capacity(edges.len());
332 let mut next_successor = Vec::with_capacity(edges.len());
333 let mut indegree: Vec<usize> = vec![0; stages.len()];
334 for edge in edges {
335 if let (Some(&from), Some(&to)) = (
336 stage_of_outlet.get(&edge.outlet),
337 stage_of_inlet.get(&edge.inlet),
338 ) {
339 successor_to.push(to);
340 next_successor.push(first_successor[from]);
341 first_successor[from] = successor_to.len() - 1;
342 indegree[to] += 1;
343 }
344 }
345
346 let mut stack: Vec<usize> = (0..stages.len()).filter(|&i| indegree[i] == 0).collect();
349 let mut visited = 0_usize;
350 while let Some(stage) = stack.pop() {
351 visited += 1;
352 let mut successor = first_successor[stage];
353 while successor != NO_SUCCESSOR {
354 let next = successor_to[successor];
355 indegree[next] -= 1;
356 if indegree[next] == 0 {
357 stack.push(next);
358 }
359 successor = next_successor[successor];
360 }
361 }
362
363 visited != stages.len()
364}
365
366fn compute_segments(stages: &[StageRecord]) -> Vec<FusedSegment> {
367 let mut segments = Vec::with_capacity(1);
368 let mut current = Vec::with_capacity(stages.len());
369
370 for (index, stage) in stages.iter().enumerate() {
371 if stage.spec.async_boundary && !current.is_empty() {
372 segments.push(FusedSegment {
373 stage_indices: std::mem::take(&mut current),
374 });
375 }
376 current.push(index);
377 if stage.spec.async_boundary {
378 segments.push(FusedSegment {
379 stage_indices: std::mem::take(&mut current),
380 });
381 }
382 }
383
384 if !current.is_empty() {
385 segments.push(FusedSegment {
386 stage_indices: current,
387 });
388 }
389
390 segments
391}
392
393pub struct GraphDsl;
394
395impl GraphDsl {
396 pub fn create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
397 where
398 S: Shape,
399 F: FnOnce(&mut GraphBuilder) -> S,
400 {
401 let mut builder = GraphBuilder::default();
402 let shape = build(&mut builder);
403 builder.finish(shape)
404 }
405
406 pub fn try_create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
407 where
408 S: Shape,
409 F: FnOnce(&mut GraphBuilder) -> StreamResult<S>,
410 {
411 let mut builder = GraphBuilder::default();
412 let shape = build(&mut builder)?;
413 builder.finish(shape)
414 }
415
416 pub fn partial<S, F>(build: F) -> PartialGraph<S>
417 where
418 S: Shape,
419 F: Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync + 'static,
420 {
421 PartialGraph {
422 build: Arc::new(build),
423 attributes: Attributes::default(),
424 }
425 }
426}
427
/// A value that can report the [`Shape`] it exposes.
pub trait Graph {
    /// Concrete shape type exposed by this graph.
    type Shape: Shape;

    /// Returns the graph's shape by value.
    fn shape(&self) -> Self::Shape;
}
433
/// A group of stage indices, in stage order, that belong to the same segment.
#[derive(Clone, Debug)]
pub struct FusedSegment {
    /// Indices into the owning graph's stage list.
    stage_indices: Vec<usize>,
}

impl FusedSegment {
    /// Returns the stage indices of this segment as a borrowed slice.
    #[must_use]
    pub fn stage_indices(&self) -> &[usize] {
        self.stage_indices.as_slice()
    }
}
445
/// A graph description produced by finishing a `GraphBuilder` without errors.
pub struct GraphBlueprint<S: Shape> {
    /// Shape the graph exposes to the outside.
    pub(super) shape: S,
    /// Stages in the order they were added to the builder.
    pub(super) stages: Vec<StageRecord>,
    /// Validated connections between stage ports.
    pub(super) edges: Vec<Edge>,
    /// Segments computed from `stages` by `compute_segments`.
    pub(super) segments: Vec<FusedSegment>,
    /// Attributes of the blueprint as a whole; default when freshly built.
    pub(super) attributes: Attributes,
}
453
454impl<S: Shape + Clone> Clone for GraphBlueprint<S> {
455 fn clone(&self) -> Self {
456 Self {
457 shape: self.shape.clone(),
458 stages: self.stages.clone(),
459 edges: self.edges.clone(),
460 segments: self.segments.clone(),
461 attributes: self.attributes.clone(),
462 }
463 }
464}
465
466impl<S: Shape + fmt::Debug> fmt::Debug for GraphBlueprint<S> {
467 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
468 f.debug_struct("GraphBlueprint")
469 .field("shape", &self.shape)
470 .field("stages", &self.stages)
471 .field("edges", &self.edges)
472 .field("segments", &self.segments)
473 .field("attributes", &self.attributes)
474 .finish()
475 }
476}
477
478impl<S: Shape> GraphBlueprint<S> {
479 #[must_use]
480 pub fn shape(&self) -> S {
481 self.shape.clone()
482 }
483
484 #[must_use]
485 pub fn stage_count(&self) -> usize {
486 self.stages.len()
487 }
488
489 #[must_use]
490 pub fn edge_count(&self) -> usize {
491 self.edges.len()
492 }
493
494 #[must_use]
495 pub fn segments(&self) -> &[FusedSegment] {
496 &self.segments
497 }
498
499 #[must_use]
500 pub fn attributes(&self) -> &Attributes {
501 &self.attributes
502 }
503
504 #[must_use]
505 pub fn with_attributes(mut self, attributes: Attributes) -> Self {
506 self.attributes = attributes;
507 self
508 }
509
510 #[must_use]
511 pub fn add_attributes(mut self, attributes: Attributes) -> Self {
512 self.attributes = self.attributes.and(attributes);
513 self
514 }
515
516 #[must_use]
517 pub fn named(self, name: impl Into<String>) -> Self {
518 self.add_attributes(Attributes::named(name))
519 }
520}
521
522impl<S: Shape> Graph for GraphBlueprint<S> {
523 type Shape = S;
524
525 fn shape(&self) -> Self::Shape {
526 self.shape()
527 }
528}
529
/// A reusable graph fragment: a shared build closure plus attributes.
#[derive(Clone)]
pub struct PartialGraph<S: Shape> {
    /// Closure that adds the fragment to a builder and returns its shape.
    build: Arc<PartialGraphBuilder<S>>,
    /// Attributes attached to the fragment.
    attributes: Attributes,
}
535
536impl<S: Shape> PartialGraph<S> {
537 pub fn build(&self, builder: &mut GraphBuilder) -> StreamResult<S> {
538 (self.build)(builder)
539 }
540
541 #[must_use]
542 pub fn attributes(&self) -> &Attributes {
543 &self.attributes
544 }
545
546 #[must_use]
547 pub fn with_attributes(mut self, attributes: Attributes) -> Self {
548 self.attributes = attributes;
549 self
550 }
551
552 #[must_use]
553 pub fn add_attributes(mut self, attributes: Attributes) -> Self {
554 self.attributes = self.attributes.and(attributes);
555 self
556 }
557
558 #[must_use]
559 pub fn named(self, name: impl Into<String>) -> Self {
560 self.add_attributes(Attributes::named(name))
561 }
562}
563
564impl<S: Shape> std::fmt::Debug for PartialGraph<S> {
565 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
566 f.debug_struct("PartialGraph")
567 .field("attributes", &self.attributes)
568 .finish_non_exhaustive()
569 }
570}
571
/// Alias for [`PartialGraph`]; both names refer to the same type.
pub type ImportedGraph<S> = PartialGraph<S>;
573
/// Configuration for fused execution.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct FusedExecutionConfig {
    /// Event limit; presumably an upper bound on events processed by one
    /// fused run — TODO confirm against the interpreter that reads it.
    pub event_limit: usize,
}

impl Default for FusedExecutionConfig {
    /// Returns a configuration with `event_limit` set to 100,000,000.
    fn default() -> Self {
        const DEFAULT_EVENT_LIMIT: usize = 100_000_000;

        FusedExecutionConfig {
            event_limit: DEFAULT_EVENT_LIMIT,
        }
    }
}
586
587#[derive(Clone, Copy, Debug, PartialEq, Eq)]
592pub struct AsyncBoundaryExecutionConfig {
593 pub fused: FusedExecutionConfig,
594 pub buffer_size: usize,
595}
596
597impl Default for AsyncBoundaryExecutionConfig {
598 fn default() -> Self {
599 Self {
600 fused: FusedExecutionConfig::default(),
601 buffer_size: 16,
602 }
603 }
604}
605
/// Result of an execution that collects its output into a vector.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FusedExecutionReport<T> {
    /// Collected output elements.
    pub output: Vec<T>,
    /// Event count; presumably the number of events processed — TODO confirm.
    pub events: usize,
    /// Count of async boundary crossings; presumably per run — TODO confirm.
    pub async_boundary_crossings: usize,
}
612
/// Result of an execution that yields a single terminal value.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FusedTerminalReport<T> {
    /// The terminal result value.
    pub result: T,
    /// Event count; presumably the number of events processed — TODO confirm.
    pub events: usize,
    /// Count of async boundary crossings; presumably per run — TODO confirm.
    pub async_boundary_crossings: usize,
}