1use super::*;
12use crate::Attributes;
13
/// Type-erased build closure stored by [`PartialGraph`].
///
/// It receives the builder the fragment is imported into and returns the
/// fragment's shape (or an error). Because it is `Fn` rather than `FnOnce`, it
/// can be invoked once per import; `Send + Sync` let the `Arc` that holds it be
/// shared across threads.
type PartialGraphBuilder<S> = dyn Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync;
15
/// Bookkeeping kept by [`GraphBuilder`] for every registered port.
///
/// Used to validate connections (direction and element type) and to render
/// readable validation error messages.
#[derive(Clone, Debug)]
struct PortRecord {
    // Whether this port is an inlet or an outlet.
    kind: PortKind,
    // Element type carried by the port; a connection requires equal ids on both ends.
    type_id: TypeId,
    // Human-readable element type name, used in error messages.
    type_name: &'static str,
    // Port name, shared with the `AnyInlet`/`AnyOutlet` it was registered from.
    name: Arc<str>,
}
23
24#[derive(Clone, Debug)]
25pub(super) struct Edge {
26 pub(super) outlet: PortId,
27 pub(super) inlet: PortId,
28}
29
/// A stage added to the graph: its spec plus, for opaque stages, a logic factory.
#[derive(Clone)]
pub(super) struct StageRecord {
    /// Stage description produced by `GraphStage::stage_spec_with_ports`, with
    /// the attributes passed to `add_with_attributes` merged in.
    pub(super) spec: StageSpec,
    /// Calls the stage's `create_logic` each time it is invoked; set only when
    /// `spec.kind` is `StageKind::Opaque`, otherwise `None`.
    pub(super) logic_factory: Option<Arc<dyn Fn() -> GraphStageLogic + Send + Sync>>,
}
35
36impl std::fmt::Debug for StageRecord {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("StageRecord")
39 .field("spec", &self.spec)
40 .field("has_logic", &self.logic_factory.is_some())
41 .finish()
42 }
43}
44
/// Mutable state accumulated while a graph is being described.
///
/// Holds every registered port, the stages added so far, the edges connecting
/// outlets to inlets, and any errors recorded along the way; recorded errors
/// make the final validation in `finish` fail.
#[derive(Debug, Default)]
pub struct GraphBuilder {
    // Passed to `GraphStage::allocate_shape` to create each new stage's ports.
    allocator: PortAllocator,
    // Direction, element type and name of every registered port, keyed by id.
    ports: HashMap<PortId, PortRecord>,
    // Stages in insertion order; `FusedSegment`s refer to them by index.
    stages: Vec<StageRecord>,
    // Accepted outlet -> inlet connections.
    edges: Vec<Edge>,
    // Errors recorded by `connect`/`wire`/`try_wire`, reported by `finish`.
    errors: Vec<StreamError>,
}
58
59impl GraphBuilder {
60 #[must_use]
61 pub fn add<G: GraphStage>(&mut self, stage: G) -> G::Shape {
62 self.add_with_attributes(stage, Attributes::default())
63 }
64
65 #[must_use]
66 pub fn add_with_attributes<G: GraphStage>(
67 &mut self,
68 stage: G,
69 attributes: Attributes,
70 ) -> G::Shape {
71 let shape = stage.allocate_shape(&mut self.allocator);
72 let inlets = shape.inlets();
73 let outlets = shape.outlets();
74 self.ports.reserve(inlets.len() + outlets.len());
75
76 for inlet in &inlets {
77 self.register_inlet(inlet);
78 }
79 for outlet in &outlets {
80 self.register_outlet(outlet);
81 }
82 let spec = stage
83 .stage_spec_with_ports(&shape, inlets, outlets)
84 .add_attributes(attributes);
85 let logic_factory = if matches!(spec.kind, StageKind::Opaque) {
86 let shape_clone = shape.clone();
87 Some(Arc::new(move || stage.create_logic(&shape_clone))
88 as Arc<dyn Fn() -> GraphStageLogic + Send + Sync>)
89 } else {
90 None
91 };
92 self.stages.push(StageRecord {
93 spec,
94 logic_factory,
95 });
96 shape
97 }
98
99 #[must_use]
100 pub fn add_named<G: GraphStage>(&mut self, stage: G, name: impl Into<String>) -> G::Shape {
101 self.add_with_attributes(stage, Attributes::named(name))
102 }
103
104 pub fn connect<T: 'static>(&mut self, outlet: Outlet<T>, inlet: Inlet<T>) -> StreamResult<()> {
105 self.connect_any(outlet.erase(), inlet.erase())
106 }
107
108 pub fn connect_any(&mut self, outlet: AnyOutlet, inlet: AnyInlet) -> StreamResult<()> {
109 match self.connect_any_unrecorded(outlet, inlet) {
110 Ok(()) => Ok(()),
111 Err(error) => self.record_error(error),
112 }
113 }
114
115 pub fn import<S: Shape>(&mut self, graph: &PartialGraph<S>) -> StreamResult<S> {
116 graph.build(self)
117 }
118
119 pub fn wire<W: WireSpec>(&mut self, spec: W) -> &mut Self {
120 if let Err(error) = spec.apply(self) {
121 let _ = self.record_error(error);
122 }
123 self
124 }
125
126 pub fn try_wire<W: WireSpec>(&mut self, spec: W) -> StreamResult<&mut Self> {
127 match spec.apply(self) {
128 Ok(()) => Ok(self),
129 Err(error) => {
130 self.errors.push(error.clone());
131 Err(error)
132 }
133 }
134 }
135
136 pub(super) fn connect_any_unrecorded(
137 &mut self,
138 outlet: AnyOutlet,
139 inlet: AnyInlet,
140 ) -> StreamResult<()> {
141 self.validate_connection(&outlet, &inlet)?;
142 self.edges.push(Edge {
143 outlet: outlet.id(),
144 inlet: inlet.id(),
145 });
146 Ok(())
147 }
148
149 pub(super) fn is_outlet_connected(&self, outlet: &AnyOutlet) -> bool {
150 self.edges.iter().any(|edge| edge.outlet == outlet.id())
151 }
152
153 pub(super) fn is_inlet_connected(&self, inlet: &AnyInlet) -> bool {
154 self.edges.iter().any(|edge| edge.inlet == inlet.id())
155 }
156
157 fn record_error(&mut self, error: StreamError) -> StreamResult<()> {
158 self.errors.push(error.clone());
159 Err(error)
160 }
161
162 fn register_inlet(&mut self, inlet: &AnyInlet) {
163 self.ports.insert(
164 inlet.id(),
165 PortRecord {
166 kind: PortKind::Inlet,
167 type_id: inlet.type_id(),
168 type_name: inlet.type_name(),
169 name: Arc::clone(&inlet.name),
170 },
171 );
172 }
173
174 fn register_outlet(&mut self, outlet: &AnyOutlet) {
175 self.ports.insert(
176 outlet.id(),
177 PortRecord {
178 kind: PortKind::Outlet,
179 type_id: outlet.type_id(),
180 type_name: outlet.type_name(),
181 name: Arc::clone(&outlet.name),
182 },
183 );
184 }
185
186 fn validate_connection(&self, outlet: &AnyOutlet, inlet: &AnyInlet) -> StreamResult<()> {
187 let outlet_record = self.ports.get(&outlet.id()).ok_or_else(|| {
188 StreamError::GraphValidation(format!("unknown outlet {}", outlet.name()))
189 })?;
190 let inlet_record = self.ports.get(&inlet.id()).ok_or_else(|| {
191 StreamError::GraphValidation(format!("unknown inlet {}", inlet.name()))
192 })?;
193
194 if outlet_record.kind != PortKind::Outlet {
195 return Err(StreamError::GraphValidation(format!(
196 "{} is not an outlet",
197 outlet_record.name
198 )));
199 }
200 if inlet_record.kind != PortKind::Inlet {
201 return Err(StreamError::GraphValidation(format!(
202 "{} is not an inlet",
203 inlet_record.name
204 )));
205 }
206 if outlet_record.type_id != inlet_record.type_id {
207 return Err(StreamError::GraphValidation(format!(
208 "cannot connect outlet {} ({}) to inlet {} ({})",
209 outlet_record.name,
210 outlet_record.type_name,
211 inlet_record.name,
212 inlet_record.type_name
213 )));
214 }
215 if self.edges.iter().any(|edge| edge.outlet == outlet.id()) {
216 return Err(StreamError::GraphValidation(format!(
217 "outlet {} is already connected",
218 outlet_record.name
219 )));
220 }
221 if self.edges.iter().any(|edge| edge.inlet == inlet.id()) {
222 return Err(StreamError::GraphValidation(format!(
223 "inlet {} is already connected",
224 inlet_record.name
225 )));
226 }
227
228 Ok(())
229 }
230
231 fn finish<S: Shape>(self, shape: S) -> StreamResult<GraphBlueprint<S>> {
232 let mut errors = self.errors;
233 let connected_inlets: HashSet<PortId> = self.edges.iter().map(|edge| edge.inlet).collect();
234 let connected_outlets: HashSet<PortId> =
235 self.edges.iter().map(|edge| edge.outlet).collect();
236
237 let open_inlets: HashSet<PortId> = self
238 .ports
239 .iter()
240 .filter_map(|(id, port)| {
241 (port.kind == PortKind::Inlet && !connected_inlets.contains(id)).then_some(*id)
242 })
243 .collect();
244 let open_outlets: HashSet<PortId> = self
245 .ports
246 .iter()
247 .filter_map(|(id, port)| {
248 (port.kind == PortKind::Outlet && !connected_outlets.contains(id)).then_some(*id)
249 })
250 .collect();
251
252 let result_inlets: HashSet<PortId> = shape.inlets().iter().map(AnyInlet::id).collect();
253 let result_outlets: HashSet<PortId> = shape.outlets().iter().map(AnyOutlet::id).collect();
254
255 for inlet in shape.inlets() {
256 match self.ports.get(&inlet.id()) {
257 Some(port)
258 if port.kind == PortKind::Inlet
259 && port.type_id == inlet.type_id()
260 && port.name.as_ref() == inlet.name() => {}
261 Some(port) if port.kind == PortKind::Inlet => {
262 errors.push(StreamError::GraphValidation(format!(
263 "result shape inlet {} does not match registered inlet {} ({})",
264 inlet.name(),
265 port.name,
266 port.type_name
267 )));
268 }
269 Some(port) => errors.push(StreamError::GraphValidation(format!(
270 "result shape references non-inlet port {}",
271 port.name
272 ))),
273 None => errors.push(StreamError::GraphValidation(format!(
274 "result shape references unknown inlet {}",
275 inlet.name()
276 ))),
277 }
278 }
279 for outlet in shape.outlets() {
280 match self.ports.get(&outlet.id()) {
281 Some(port)
282 if port.kind == PortKind::Outlet
283 && port.type_id == outlet.type_id()
284 && port.name.as_ref() == outlet.name() => {}
285 Some(port) if port.kind == PortKind::Outlet => {
286 errors.push(StreamError::GraphValidation(format!(
287 "result shape outlet {} does not match registered outlet {} ({})",
288 outlet.name(),
289 port.name,
290 port.type_name
291 )));
292 }
293 Some(port) => errors.push(StreamError::GraphValidation(format!(
294 "result shape references non-outlet port {}",
295 port.name
296 ))),
297 None => errors.push(StreamError::GraphValidation(format!(
298 "result shape references unknown outlet {}",
299 outlet.name()
300 ))),
301 }
302 }
303
304 if open_inlets != result_inlets {
305 errors.push(StreamError::GraphValidation(format!(
306 "result shape inlets do not match open inlets: open={:?}, result={:?}",
307 describe_ports(&self.ports, &open_inlets),
308 describe_ports(&self.ports, &result_inlets)
309 )));
310 }
311 if open_outlets != result_outlets {
312 errors.push(StreamError::GraphValidation(format!(
313 "result shape outlets do not match open outlets: open={:?}, result={:?}",
314 describe_ports(&self.ports, &open_outlets),
315 describe_ports(&self.ports, &result_outlets)
316 )));
317 }
318
319 if !errors.is_empty() {
320 return Err(StreamError::GraphValidation(
321 errors
322 .into_iter()
323 .map(|error| error.to_string())
324 .collect::<Vec<_>>()
325 .join("; "),
326 ));
327 }
328
329 let segments = compute_segments(&self.stages);
330 Ok(GraphBlueprint {
331 shape,
332 stages: self.stages,
333 edges: self.edges,
334 segments,
335 attributes: Attributes::default(),
336 })
337 }
338}
339
340fn describe_ports(ports: &HashMap<PortId, PortRecord>, ids: &HashSet<PortId>) -> Vec<String> {
341 let mut names = ids
342 .iter()
343 .map(|id| {
344 ports
345 .get(id)
346 .map(|port| port.name.as_ref().to_owned())
347 .unwrap_or_else(|| format!("unknown#{}", id.as_usize()))
348 })
349 .collect::<Vec<_>>();
350 names.sort();
351 names
352}
353
354fn compute_segments(stages: &[StageRecord]) -> Vec<FusedSegment> {
355 let mut segments = Vec::with_capacity(1);
356 let mut current = Vec::with_capacity(stages.len());
357
358 for (index, stage) in stages.iter().enumerate() {
359 if stage.spec.async_boundary && !current.is_empty() {
360 segments.push(FusedSegment {
361 stage_indices: std::mem::take(&mut current),
362 });
363 }
364 current.push(index);
365 if stage.spec.async_boundary {
366 segments.push(FusedSegment {
367 stage_indices: std::mem::take(&mut current),
368 });
369 }
370 }
371
372 if !current.is_empty() {
373 segments.push(FusedSegment {
374 stage_indices: current,
375 });
376 }
377
378 segments
379}
380
/// Namespace for the graph-building entry points (`create`, `try_create`, `partial`).
///
/// It carries no state; the derives make it usable wherever `Debug`, `Clone`,
/// `Copy` or `Default` are required.
#[derive(Clone, Copy, Debug, Default)]
pub struct GraphDsl;
386
387impl GraphDsl {
388 pub fn create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
389 where
390 S: Shape,
391 F: FnOnce(&mut GraphBuilder) -> S,
392 {
393 let mut builder = GraphBuilder::default();
394 let shape = build(&mut builder);
395 builder.finish(shape)
396 }
397
398 pub fn try_create<S, F>(build: F) -> StreamResult<GraphBlueprint<S>>
399 where
400 S: Shape,
401 F: FnOnce(&mut GraphBuilder) -> StreamResult<S>,
402 {
403 let mut builder = GraphBuilder::default();
404 let shape = build(&mut builder)?;
405 builder.finish(shape)
406 }
407
408 pub fn partial<S, F>(build: F) -> PartialGraph<S>
409 where
410 S: Shape,
411 F: Fn(&mut GraphBuilder) -> StreamResult<S> + Send + Sync + 'static,
412 {
413 PartialGraph {
414 build: Arc::new(build),
415 attributes: Attributes::default(),
416 }
417 }
418}
419
/// Common interface for graph values that expose a [`Shape`].
pub trait Graph {
    /// Concrete shape type exposed by this graph.
    type Shape: Shape;

    /// Returns this graph's shape by value.
    fn shape(&self) -> Self::Shape;
}
426
/// A group of stage indices produced by splitting the stage list at async
/// boundaries.
///
/// `PartialEq`/`Eq` are derived so segment layouts can be compared directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct FusedSegment {
    stage_indices: Vec<usize>,
}

impl FusedSegment {
    /// Indices (into the owning blueprint's stages) of the stages in this segment.
    #[must_use]
    pub fn stage_indices(&self) -> &[usize] {
        &self.stage_indices
    }
}
441
/// A built graph: result shape, stages, edges, fused segments and graph-level
/// attributes.
///
/// In this file it is produced by `GraphBuilder::finish` (reached through
/// `GraphDsl::create` / `GraphDsl::try_create`) only after validation passes.
pub struct GraphBlueprint<S: Shape> {
    // Result shape returned by the build closure; checked against the open ports.
    pub(super) shape: S,
    // Stage records in insertion order; `segments` refer to them by index.
    pub(super) stages: Vec<StageRecord>,
    // Outlet -> inlet connections.
    pub(super) edges: Vec<Edge>,
    // Partition of `stages` computed from their async-boundary flags.
    pub(super) segments: Vec<FusedSegment>,
    // Graph-level attributes; default when built, changed via the attribute setters.
    pub(super) attributes: Attributes,
}
454
455impl<S: Shape + Clone> Clone for GraphBlueprint<S> {
456 fn clone(&self) -> Self {
457 Self {
458 shape: self.shape.clone(),
459 stages: self.stages.clone(),
460 edges: self.edges.clone(),
461 segments: self.segments.clone(),
462 attributes: self.attributes.clone(),
463 }
464 }
465}
466
467impl<S: Shape + fmt::Debug> fmt::Debug for GraphBlueprint<S> {
468 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
469 f.debug_struct("GraphBlueprint")
470 .field("shape", &self.shape)
471 .field("stages", &self.stages)
472 .field("edges", &self.edges)
473 .field("segments", &self.segments)
474 .field("attributes", &self.attributes)
475 .finish()
476 }
477}
478
479impl<S: Shape> GraphBlueprint<S> {
480 #[must_use]
481 pub fn shape(&self) -> S {
482 self.shape.clone()
483 }
484
485 #[must_use]
486 pub fn stage_count(&self) -> usize {
487 self.stages.len()
488 }
489
490 #[must_use]
491 pub fn edge_count(&self) -> usize {
492 self.edges.len()
493 }
494
495 #[must_use]
496 pub fn segments(&self) -> &[FusedSegment] {
497 &self.segments
498 }
499
500 #[must_use]
501 pub fn attributes(&self) -> &Attributes {
502 &self.attributes
503 }
504
505 #[must_use]
506 pub fn with_attributes(mut self, attributes: Attributes) -> Self {
507 self.attributes = attributes;
508 self
509 }
510
511 #[must_use]
512 pub fn add_attributes(mut self, attributes: Attributes) -> Self {
513 self.attributes = self.attributes.and(attributes);
514 self
515 }
516
517 #[must_use]
518 pub fn named(self, name: impl Into<String>) -> Self {
519 self.add_attributes(Attributes::named(name))
520 }
521}
522
523impl<S: Shape> Graph for GraphBlueprint<S> {
524 type Shape = S;
525
526 fn shape(&self) -> Self::Shape {
527 self.shape()
528 }
529}
530
/// A reusable graph fragment: a shared build closure that can be run against
/// any [`GraphBuilder`] (e.g. via `GraphBuilder::import`), plus attributes.
#[derive(Clone)]
pub struct PartialGraph<S: Shape> {
    // Shared closure that populates a builder and returns the fragment's shape.
    build: Arc<PartialGraphBuilder<S>>,
    // NOTE(review): stored and exposed via accessors, but `build` does not apply them.
    attributes: Attributes,
}
540
541impl<S: Shape> PartialGraph<S> {
542 pub fn build(&self, builder: &mut GraphBuilder) -> StreamResult<S> {
543 (self.build)(builder)
544 }
545
546 #[must_use]
547 pub fn attributes(&self) -> &Attributes {
548 &self.attributes
549 }
550
551 #[must_use]
552 pub fn with_attributes(mut self, attributes: Attributes) -> Self {
553 self.attributes = attributes;
554 self
555 }
556
557 #[must_use]
558 pub fn add_attributes(mut self, attributes: Attributes) -> Self {
559 self.attributes = self.attributes.and(attributes);
560 self
561 }
562
563 #[must_use]
564 pub fn named(self, name: impl Into<String>) -> Self {
565 self.add_attributes(Attributes::named(name))
566 }
567}
568
569impl<S: Shape> std::fmt::Debug for PartialGraph<S> {
570 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
571 f.debug_struct("PartialGraph")
572 .field("attributes", &self.attributes)
573 .finish_non_exhaustive()
574 }
575}
576
577pub type ImportedGraph<S> = PartialGraph<S>;
578
/// Settings for running a fused graph.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct FusedExecutionConfig {
    /// Event budget for a fused run (default `100_000_000`); presumably the
    /// executor stops or fails once it is exceeded — TODO confirm in the executor.
    pub event_limit: usize,
}

impl Default for FusedExecutionConfig {
    fn default() -> Self {
        const DEFAULT_EVENT_LIMIT: usize = 100_000_000;
        FusedExecutionConfig {
            event_limit: DEFAULT_EVENT_LIMIT,
        }
    }
}
594
595#[derive(Clone, Copy, Debug, PartialEq, Eq)]
600pub struct AsyncBoundaryExecutionConfig {
601 pub fused: FusedExecutionConfig,
602 pub buffer_size: usize,
603}
604
605impl Default for AsyncBoundaryExecutionConfig {
606 fn default() -> Self {
607 Self {
608 fused: FusedExecutionConfig::default(),
609 buffer_size: 16,
610 }
611 }
612}
613
/// Outcome of a fused run that collects the emitted elements.
///
/// NOTE(review): field meanings are inferred from their names — confirm
/// against the executor that fills them in.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct FusedExecutionReport<T> {
    /// Elements collected from the run.
    pub output: Vec<T>,
    /// Number of events processed.
    pub events: usize,
    /// Number of async-boundary crossings observed.
    pub async_boundary_crossings: usize,
}

/// Outcome of a fused run that produces a single terminal value.
///
/// NOTE(review): field meanings are inferred from their names — confirm
/// against the executor that fills them in.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct FusedTerminalReport<T> {
    /// Terminal value produced by the run.
    pub result: T,
    /// Number of events processed.
    pub events: usize,
    /// Number of async-boundary crossings observed.
    pub async_boundary_crossings: usize,
}