1use super::*;
2use crate::Attributes;
3
4type PartialGraphBuilder<S> = dyn Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync;
5
6#[derive(Clone, Debug)]
7struct PortRecord {
8 kind: PortKind,
9 type_id: TypeId,
10 type_name: &'static str,
11 name: Arc<str>,
12}
13
14#[derive(Clone, Debug)]
15pub(super) struct Edge {
16 pub(super) outlet: PortId,
17 pub(super) inlet: PortId,
18}
19
20#[derive(Clone)]
21pub(super) struct StageRecord {
22 pub(super) spec: StageSpec,
23 pub(super) logic_factory: Option<Arc<dyn Fn() -> GraphStageLogic + Send + Sync>>,
24}
25
26impl std::fmt::Debug for StageRecord {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("StageRecord")
29 .field("spec", &self.spec)
30 .field("has_logic", &self.logic_factory.is_some())
31 .finish()
32 }
33}
34
35#[derive(Debug, Default)]
36pub struct GraphBuilder {
37 allocator: PortAllocator,
38 ports: HashMap<PortId, PortRecord>,
39 stages: Vec<StageRecord>,
40 edges: Vec<Edge>,
41 errors: Vec<StreamError>,
42}
43
44impl GraphBuilder {
45 #[must_use]
46 pub fn add<G: GraphStage>(&mut self, stage: G) -> G::Shape {
47 self.add_with_attributes(stage, Attributes::default())
48 }
49
50 #[must_use]
51 pub fn add_with_attributes<G: GraphStage>(
52 &mut self,
53 stage: G,
54 attributes: Attributes,
55 ) -> G::Shape {
56 let shape = stage.allocate_shape(&mut self.allocator);
57 let inlets = shape.inlets();
58 let outlets = shape.outlets();
59 self.ports.reserve(inlets.len() + outlets.len());
60
61 for inlet in &inlets {
62 self.register_inlet(inlet);
63 }
64 for outlet in &outlets {
65 self.register_outlet(outlet);
66 }
67 let spec = stage
68 .stage_spec_with_ports(&shape, inlets, outlets)
69 .add_attributes(attributes);
70 let logic_factory = if matches!(spec.kind, StageKind::Opaque) {
71 let shape_clone = shape.clone();
72 Some(Arc::new(move || stage.create_logic(&shape_clone))
73 as Arc<dyn Fn() -> GraphStageLogic + Send + Sync>)
74 } else {
75 None
76 };
77 self.stages.push(StageRecord {
78 spec,
79 logic_factory,
80 });
81 shape
82 }
83
84 #[must_use]
85 pub fn add_named<G: GraphStage>(&mut self, stage: G, name: impl Into<String>) -> G::Shape {
86 self.add_with_attributes(stage, Attributes::named(name))
87 }
88
89 pub fn connect<T: 'static>(&mut self, outlet: Outlet<T>, inlet: Inlet<T>) -> StreamResult<()> {
90 self.connect_any(outlet.erase(), inlet.erase())
91 }
92
93 pub fn connect_any(&mut self, outlet: AnyOutlet, inlet: AnyInlet) -> StreamResult<()> {
94 match self.validate_connection(&outlet, &inlet) {
95 Ok(()) => {
96 self.edges.push(Edge {
97 outlet: outlet.id(),
98 inlet: inlet.id(),
99 });
100 Ok(())
101 }
102 Err(error) => {
103 self.errors.push(error.clone());
104 Err(error)
105 }
106 }
107 }
108
109 pub fn import<S: Shape>(&mut self, graph: &PartialGraph<S>) -> StreamResult<S> {
110 graph.build(self)
111 }
112
113 fn register_inlet(&mut self, inlet: &AnyInlet) {
114 self.ports.insert(
115 inlet.id(),
116 PortRecord {
117 kind: PortKind::Inlet,
118 type_id: inlet.type_id(),
119 type_name: inlet.type_name(),
120 name: Arc::clone(&inlet.name),
121 },
122 );
123 }
124
125 fn register_outlet(&mut self, outlet: &AnyOutlet) {
126 self.ports.insert(
127 outlet.id(),
128 PortRecord {
129 kind: PortKind::Outlet,
130 type_id: outlet.type_id(),
131 type_name: outlet.type_name(),
132 name: Arc::clone(&outlet.name),
133 },
134 );
135 }
136
137 fn validate_connection(&self, outlet: &AnyOutlet, inlet: &AnyInlet) -> StreamResult<()> {
138 let outlet_record = self.ports.get(&outlet.id()).ok_or_else(|| {
139 StreamError::GraphValidation(format!("unknown outlet {}", outlet.name()))
140 })?;
141 let inlet_record = self.ports.get(&inlet.id()).ok_or_else(|| {
142 StreamError::GraphValidation(format!("unknown inlet {}", inlet.name()))
143 })?;
144
145 if outlet_record.kind != PortKind::Outlet {
146 return Err(StreamError::GraphValidation(format!(
147 "{} is not an outlet",
148 outlet_record.name
149 )));
150 }
151 if inlet_record.kind != PortKind::Inlet {
152 return Err(StreamError::GraphValidation(format!(
153 "{} is not an inlet",
154 inlet_record.name
155 )));
156 }
157 if outlet_record.type_id != inlet_record.type_id {
158 return Err(StreamError::GraphValidation(format!(
159 "cannot connect outlet {} ({}) to inlet {} ({})",
160 outlet_record.name,
161 outlet_record.type_name,
162 inlet_record.name,
163 inlet_record.type_name
164 )));
165 }
166 if self.edges.iter().any(|edge| edge.outlet == outlet.id()) {
167 return Err(StreamError::GraphValidation(format!(
168 "outlet {} is already connected",
169 outlet_record.name
170 )));
171 }
172 if self.edges.iter().any(|edge| edge.inlet == inlet.id()) {
173 return Err(StreamError::GraphValidation(format!(
174 "inlet {} is already connected",
175 inlet_record.name
176 )));
177 }
178
179 Ok(())
180 }
181
182 fn finish<S: Shape>(self, shape: S) -> StreamResult<GraphBlueprint<S>> {
183 let mut errors = self.errors;
184 let connected_inlets: HashSet<PortId> = self.edges.iter().map(|edge| edge.inlet).collect();
185 let connected_outlets: HashSet<PortId> =
186 self.edges.iter().map(|edge| edge.outlet).collect();
187
188 let open_inlets: HashSet<PortId> = self
189 .ports
190 .iter()
191 .filter_map(|(id, port)| {
192 (port.kind == PortKind::Inlet && !connected_inlets.contains(id)).then_some(*id)
193 })
194 .collect();
195 let open_outlets: HashSet<PortId> = self
196 .ports
197 .iter()
198 .filter_map(|(id, port)| {
199 (port.kind == PortKind::Outlet && !connected_outlets.contains(id)).then_some(*id)
200 })
201 .collect();
202
203 let result_inlets: HashSet<PortId> = shape.inlets().iter().map(AnyInlet::id).collect();
204 let result_outlets: HashSet<PortId> = shape.outlets().iter().map(AnyOutlet::id).collect();
205
206 for inlet in shape.inlets() {
207 match self.ports.get(&inlet.id()) {
208 Some(port)
209 if port.kind == PortKind::Inlet
210 && port.type_id == inlet.type_id()
211 && port.name.as_ref() == inlet.name() => {}
212 Some(port) if port.kind == PortKind::Inlet => {
213 errors.push(StreamError::GraphValidation(format!(
214 "result shape inlet {} does not match registered inlet {} ({})",
215 inlet.name(),
216 port.name,
217 port.type_name
218 )));
219 }
220 Some(port) => errors.push(StreamError::GraphValidation(format!(
221 "result shape references non-inlet port {}",
222 port.name
223 ))),
224 None => errors.push(StreamError::GraphValidation(format!(
225 "result shape references unknown inlet {}",
226 inlet.name()
227 ))),
228 }
229 }
230 for outlet in shape.outlets() {
231 match self.ports.get(&outlet.id()) {
232 Some(port)
233 if port.kind == PortKind::Outlet
234 && port.type_id == outlet.type_id()
235 && port.name.as_ref() == outlet.name() => {}
236 Some(port) if port.kind == PortKind::Outlet => {
237 errors.push(StreamError::GraphValidation(format!(
238 "result shape outlet {} does not match registered outlet {} ({})",
239 outlet.name(),
240 port.name,
241 port.type_name
242 )));
243 }
244 Some(port) => errors.push(StreamError::GraphValidation(format!(
245 "result shape references non-outlet port {}",
246 port.name
247 ))),
248 None => errors.push(StreamError::GraphValidation(format!(
249 "result shape references unknown outlet {}",
250 outlet.name()
251 ))),
252 }
253 }
254
255 if open_inlets != result_inlets {
256 errors.push(StreamError::GraphValidation(format!(
257 "result shape inlets do not match open inlets: open={:?}, result={:?}",
258 describe_ports(&self.ports, &open_inlets),
259 describe_ports(&self.ports, &result_inlets)
260 )));
261 }
262 if open_outlets != result_outlets {
263 errors.push(StreamError::GraphValidation(format!(
264 "result shape outlets do not match open outlets: open={:?}, result={:?}",
265 describe_ports(&self.ports, &open_outlets),
266 describe_ports(&self.ports, &result_outlets)
267 )));
268 }
269
270 if !errors.is_empty() {
271 return Err(StreamError::GraphValidation(
272 errors
273 .into_iter()
274 .map(|error| error.to_string())
275 .collect::<Vec<_>>()
276 .join("; "),
277 ));
278 }
279
280 let segments = compute_segments(&self.stages);
281 Ok(GraphBlueprint {
282 shape,
283 stages: self.stages,
284 edges: self.edges,
285 segments,
286 attributes: Attributes::default(),
287 })
288 }
289}
290
291fn describe_ports(ports: &HashMap<PortId, PortRecord>, ids: &HashSet<PortId>) -> Vec<String> {
292 let mut names = ids
293 .iter()
294 .map(|id| {
295 ports
296 .get(id)
297 .map(|port| port.name.as_ref().to_owned())
298 .unwrap_or_else(|| format!("unknown#{}", id.as_usize()))
299 })
300 .collect::<Vec<_>>();
301 names.sort();
302 names
303}
304
305fn compute_segments(stages: &[StageRecord]) -> Vec<FusedSegment> {
306 let mut segments = Vec::with_capacity(1);
307 let mut current = Vec::with_capacity(stages.len());
308
309 for (index, stage) in stages.iter().enumerate() {
310 if stage.spec.async_boundary && !current.is_empty() {
311 segments.push(FusedSegment {
312 stage_indices: std::mem::take(&mut current),
313 });
314 }
315 current.push(index);
316 if stage.spec.async_boundary {
317 segments.push(FusedSegment {
318 stage_indices: std::mem::take(&mut current),
319 });
320 }
321 }
322
323 if !current.is_empty() {
324 segments.push(FusedSegment {
325 stage_indices: current,
326 });
327 }
328
329 segments
330}
331
332pub struct GraphDsl;
333
334impl GraphDsl {
335 pub fn create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
336 where
337 S: Shape,
338 F: FnOnce(&mut GraphBuilder) -> S,
339 {
340 let mut builder = GraphBuilder::default();
341 let shape = build(&mut builder);
342 builder.finish(shape)
343 }
344
345 pub fn try_create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
346 where
347 S: Shape,
348 F: FnOnce(&mut GraphBuilder) -> StreamResult<S>,
349 {
350 let mut builder = GraphBuilder::default();
351 let shape = build(&mut builder)?;
352 builder.finish(shape)
353 }
354
355 pub fn partial<S, F>(build: F) -> PartialGraph<S>
356 where
357 S: Shape,
358 F: Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync + 'static,
359 {
360 PartialGraph {
361 build: Arc::new(build),
362 attributes: Attributes::default(),
363 }
364 }
365}
366
367pub trait Graph {
368 type Shape: Shape;
369
370 fn shape(&self) -> Self::Shape;
371}
372
/// A run of stages grouped into one segment by `compute_segments`.
#[derive(Debug, Clone)]
pub struct FusedSegment {
    /// Indices into the blueprint's stage list, in ascending order.
    stage_indices: Vec<usize>,
}

impl FusedSegment {
    /// Indices of the stages belonging to this segment.
    #[must_use]
    pub fn stage_indices(&self) -> &[usize] {
        self.stage_indices.as_slice()
    }
}
384
385pub struct GraphBlueprint<S: Shape> {
386 pub(super) shape: S,
387 pub(super) stages: Vec<StageRecord>,
388 pub(super) edges: Vec<Edge>,
389 pub(super) segments: Vec<FusedSegment>,
390 pub(super) attributes: Attributes,
391}
392
393impl<S: Shape + Clone> Clone for GraphBlueprint<S> {
394 fn clone(&self) -> Self {
395 Self {
396 shape: self.shape.clone(),
397 stages: self.stages.clone(),
398 edges: self.edges.clone(),
399 segments: self.segments.clone(),
400 attributes: self.attributes.clone(),
401 }
402 }
403}
404
405impl<S: Shape + fmt::Debug> fmt::Debug for GraphBlueprint<S> {
406 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
407 f.debug_struct("GraphBlueprint")
408 .field("shape", &self.shape)
409 .field("stages", &self.stages)
410 .field("edges", &self.edges)
411 .field("segments", &self.segments)
412 .field("attributes", &self.attributes)
413 .finish()
414 }
415}
416
417impl<S: Shape> GraphBlueprint<S> {
418 #[must_use]
419 pub fn shape(&self) -> S {
420 self.shape.clone()
421 }
422
423 #[must_use]
424 pub fn stage_count(&self) -> usize {
425 self.stages.len()
426 }
427
428 #[must_use]
429 pub fn edge_count(&self) -> usize {
430 self.edges.len()
431 }
432
433 #[must_use]
434 pub fn segments(&self) -> &[FusedSegment] {
435 &self.segments
436 }
437
438 #[must_use]
439 pub fn attributes(&self) -> &Attributes {
440 &self.attributes
441 }
442
443 #[must_use]
444 pub fn with_attributes(mut self, attributes: Attributes) -> Self {
445 self.attributes = attributes;
446 self
447 }
448
449 #[must_use]
450 pub fn add_attributes(mut self, attributes: Attributes) -> Self {
451 self.attributes = self.attributes.and(attributes);
452 self
453 }
454
455 #[must_use]
456 pub fn named(self, name: impl Into<String>) -> Self {
457 self.add_attributes(Attributes::named(name))
458 }
459}
460
461impl<S: Shape> Graph for GraphBlueprint<S> {
462 type Shape = S;
463
464 fn shape(&self) -> Self::Shape {
465 self.shape()
466 }
467}
468
469#[derive(Clone)]
470pub struct PartialGraph<S: Shape> {
471 build: Arc<PartialGraphBuilder<S>>,
472 attributes: Attributes,
473}
474
475impl<S: Shape> PartialGraph<S> {
476 pub fn build(&self, builder: &mut GraphBuilder) -> StreamResult<S> {
477 (self.build)(builder)
478 }
479
480 #[must_use]
481 pub fn attributes(&self) -> &Attributes {
482 &self.attributes
483 }
484
485 #[must_use]
486 pub fn with_attributes(mut self, attributes: Attributes) -> Self {
487 self.attributes = attributes;
488 self
489 }
490
491 #[must_use]
492 pub fn add_attributes(mut self, attributes: Attributes) -> Self {
493 self.attributes = self.attributes.and(attributes);
494 self
495 }
496
497 #[must_use]
498 pub fn named(self, name: impl Into<String>) -> Self {
499 self.add_attributes(Attributes::named(name))
500 }
501}
502
503impl<S: Shape> std::fmt::Debug for PartialGraph<S> {
504 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
505 f.debug_struct("PartialGraph")
506 .field("attributes", &self.attributes)
507 .finish_non_exhaustive()
508 }
509}
510
511pub type ImportedGraph<S> = PartialGraph<S>;
512
/// Settings for fused (single-segment) execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FusedExecutionConfig {
    /// Upper bound on the number of events processed.
    pub event_limit: usize,
}

impl Default for FusedExecutionConfig {
    fn default() -> Self {
        const DEFAULT_EVENT_LIMIT: usize = 100_000_000;
        Self {
            event_limit: DEFAULT_EVENT_LIMIT,
        }
    }
}
525
526#[derive(Clone, Copy, Debug, PartialEq, Eq)]
531pub struct AsyncBoundaryExecutionConfig {
532 pub fused: FusedExecutionConfig,
533 pub buffer_size: usize,
534}
535
536impl Default for AsyncBoundaryExecutionConfig {
537 fn default() -> Self {
538 Self {
539 fused: FusedExecutionConfig::default(),
540 buffer_size: 16,
541 }
542 }
543}
544
/// Result of an execution that collects emitted elements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusedExecutionReport<T> {
    /// Elements produced by the run.
    pub output: Vec<T>,
    /// Number of events processed.
    pub events: usize,
    /// Number of async-boundary crossings counted during the run.
    pub async_boundary_crossings: usize,
}
551
/// Result of an execution that ends in a single terminal value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FusedTerminalReport<T> {
    /// The terminal value produced by the run.
    pub result: T,
    /// Number of events processed.
    pub events: usize,
    /// Number of async-boundary crossings counted during the run.
    pub async_boundary_crossings: usize,
}