Skip to main content

ezu_graph/
graph.rs

1//! The DAG itself: building, type-checking, topology, pad propagation.
2
3use std::collections::VecDeque;
4
5use indexmap::IndexMap;
6
7use crate::node::Node;
8use crate::port::PortKind;
9
/// Name of a node in a [`Graph`], exactly as written by the style author:
/// it is the key under the style JSON's `nodes` object (for example
/// `"water_paint"`).
pub type NodeId = std::string::String;

/// Dense position of a node, handed out when the graph is built. A node
/// keeps the same index for as long as its [`Graph`] lives, and the
/// topo-sorted operations address nodes by it.
pub type NodeIx = usize;
17
18#[derive(Debug, thiserror::Error)]
19pub enum BuildError {
20    #[error("unknown node reference `{from}` -> `{to}`")]
21    UnknownRef { from: NodeId, to: NodeId },
22
23    #[error("node `{node}` has no input port named `{port}`")]
24    UnknownPort { node: NodeId, port: String },
25
26    #[error("port `{node}.{port}` already connected")]
27    DuplicateEdge { node: NodeId, port: String },
28
29    #[error(
30        "type mismatch on `{node}.{port}`: expected one of [{}], source `{src}` produces {got}",
31        accepts.iter().map(|k| k.to_string()).collect::<Vec<_>>().join(", ")
32    )]
33    TypeMismatch {
34        node: NodeId,
35        port: String,
36        src: NodeId,
37        accepts: Vec<PortKind>,
38        got: PortKind,
39    },
40
41    #[error("required port `{node}.{port}` is not connected")]
42    MissingInput { node: NodeId, port: String },
43
44    #[error("cycle detected involving node `{0}`")]
45    Cycle(NodeId),
46
47    #[error("output node `{0}` is not in the graph")]
48    UnknownOutput(NodeId),
49
50    #[error("graph has no output node")]
51    NoOutput,
52
53    #[error(
54        "output node `{node}` produces `{got}`, but the document output must produce `raster` (canvas-padded). Pipe a sprite through `place`, `tiling`, or `stamp` first."
55    )]
56    OutputKindMismatch { node: NodeId, got: PortKind },
57
58    #[error("required pad ({required}) on node `{node}` exceeds limit ({limit})")]
59    PadExceeded {
60        node: NodeId,
61        required: u32,
62        limit: u32,
63    },
64}
65
66/// An edge in the DAG: `src.output` flows into `dst.inputs()[port_ix]`.
67#[derive(Debug, Clone, Copy)]
68pub struct Edge {
69    pub src: NodeIx,
70    pub dst: NodeIx,
71    pub dst_port: usize,
72}
73
74/// A built, type-checked DAG. Tile-independent; build once per style
75/// and evaluate many times.
76pub struct Graph {
77    nodes: IndexMap<NodeId, Box<dyn Node>>,
78    /// Per-node, per-input-port edge source (None if unconnected and
79    /// the port was optional).
80    incoming: Vec<Vec<Option<NodeIx>>>,
81    /// Adjacency for downstream walks: outgoing[src] -> list of dst.
82    outgoing: Vec<Vec<NodeIx>>,
83    /// Output node index.
84    output: NodeIx,
85    /// Topological order, output last.
86    topo: Vec<NodeIx>,
87    /// Resolved output [`PortKind`] for every node, indexed by [`NodeIx`].
88    /// Polymorphic nodes (e.g. `blur` accepting both `Raster` and
89    /// `Sprite`) have their actual output kind decided here at build
90    /// time based on their connected inputs.
91    output_kinds: Vec<PortKind>,
92}
93
/// Largest pad, in pixels, that pad propagation may demand of any node.
/// Guards against runaway blurs requiring multi-tile buffers.
pub const MAX_PAD: u32 = 256;
97
98/// Builder for constructing a [`Graph`] programmatically. The style
99/// parser will drive this from JSON later; tests use it directly.
100pub struct GraphBuilder {
101    nodes: IndexMap<NodeId, Box<dyn Node>>,
102    /// Pending edges, recorded by name; resolved at `build()` time.
103    edges: Vec<EdgeSpec>,
104    output: Option<NodeId>,
105}
106
107struct EdgeSpec {
108    src: NodeId,
109    dst: NodeId,
110    dst_port: String,
111}
112
113impl GraphBuilder {
114    pub fn new() -> Self {
115        Self {
116            nodes: IndexMap::new(),
117            edges: Vec::new(),
118            output: None,
119        }
120    }
121
122    pub fn add_node(&mut self, id: impl Into<NodeId>, node: Box<dyn Node>) -> &mut Self {
123        self.nodes.insert(id.into(), node);
124        self
125    }
126
127    pub fn connect(
128        &mut self,
129        src: impl Into<NodeId>,
130        dst: impl Into<NodeId>,
131        dst_port: impl Into<String>,
132    ) -> &mut Self {
133        self.edges.push(EdgeSpec {
134            src: src.into(),
135            dst: dst.into(),
136            dst_port: dst_port.into(),
137        });
138        self
139    }
140
141    pub fn set_output(&mut self, id: impl Into<NodeId>) -> &mut Self {
142        self.output = Some(id.into());
143        self
144    }
145
146    pub fn build(self) -> Result<Graph, BuildError> {
147        let n = self.nodes.len();
148        let mut incoming: Vec<Vec<Option<NodeIx>>> = self
149            .nodes
150            .values()
151            .map(|node| vec![None; node.inputs().len()])
152            .collect();
153        let mut outgoing: Vec<Vec<NodeIx>> = vec![Vec::new(); n];
154
155        let ix_of = |id: &str| -> Option<NodeIx> { self.nodes.get_index_of(id) };
156
157        // Pass 1: wire edges (no type check yet — output kinds may be
158        // polymorphic and only resolvable in topo order).
159        for edge in &self.edges {
160            let src_ix = ix_of(&edge.src).ok_or_else(|| BuildError::UnknownRef {
161                from: edge.src.clone(),
162                to: edge.dst.clone(),
163            })?;
164            let dst_ix = ix_of(&edge.dst).ok_or_else(|| BuildError::UnknownRef {
165                from: edge.src.clone(),
166                to: edge.dst.clone(),
167            })?;
168
169            let (_, dst_node) = self
170                .nodes
171                .get_index(dst_ix)
172                .expect("dst_ix came from ix_of and is in range");
173            let port_ix = dst_node
174                .inputs()
175                .iter()
176                .position(|p| p.name == edge.dst_port)
177                .ok_or_else(|| BuildError::UnknownPort {
178                    node: edge.dst.clone(),
179                    port: edge.dst_port.clone(),
180                })?;
181
182            if incoming[dst_ix][port_ix].is_some() {
183                return Err(BuildError::DuplicateEdge {
184                    node: edge.dst.clone(),
185                    port: edge.dst_port.clone(),
186                });
187            }
188
189            incoming[dst_ix][port_ix] = Some(src_ix);
190            outgoing[src_ix].push(dst_ix);
191        }
192
193        // Required-port check.
194        for (ix, (id, node)) in self.nodes.iter().enumerate() {
195            for (port_ix, port) in node.inputs().iter().enumerate() {
196                if !port.optional && incoming[ix][port_ix].is_none() {
197                    return Err(BuildError::MissingInput {
198                        node: id.clone(),
199                        port: port.name.to_string(),
200                    });
201                }
202            }
203        }
204
205        let topo = topo_sort(n, &incoming, &self.nodes)?;
206
207        // Pass 2: walk topo order, resolve each node's output kind from
208        // its (already-resolved) upstream kinds, and check the upstream
209        // kind against each input port's `accepts` list.
210        let mut output_kinds: Vec<PortKind> = vec![PortKind::Raster; n];
211        for &ix in &topo {
212            let (id, node) = self.nodes.get_index(ix).expect("ix from topo is in range");
213            let specs = node.inputs();
214            let mut input_kinds: Vec<Option<PortKind>> = Vec::with_capacity(specs.len());
215            for (port_ix, spec) in specs.iter().enumerate() {
216                match incoming[ix][port_ix] {
217                    Some(src_ix) => {
218                        let src_kind = output_kinds[src_ix];
219                        if !spec.accepts_kind(src_kind) {
220                            let (src_id, _) = self
221                                .nodes
222                                .get_index(src_ix)
223                                .expect("src_ix from incoming is in range");
224                            return Err(BuildError::TypeMismatch {
225                                node: id.clone(),
226                                port: spec.name.to_string(),
227                                src: src_id.clone(),
228                                accepts: spec.accepts.to_vec(),
229                                got: src_kind,
230                            });
231                        }
232                        input_kinds.push(Some(src_kind));
233                    }
234                    None => input_kinds.push(None),
235                }
236            }
237            output_kinds[ix] = node.output(&input_kinds);
238        }
239
240        let output_id = self.output.ok_or(BuildError::NoOutput)?;
241        let output_ix = ix_of(&output_id).ok_or(BuildError::UnknownOutput(output_id.clone()))?;
242        // Document output must be a canvas-padded raster — anything
243        // smaller (e.g. a raw `Sprite`) will alias badly through the
244        // host's `raster_to_png` crop. Catch this at build time.
245        let output_kind = output_kinds[output_ix];
246        if output_kind != PortKind::Raster {
247            return Err(BuildError::OutputKindMismatch {
248                node: output_id.clone(),
249                got: output_kind,
250            });
251        }
252
253        Ok(Graph {
254            nodes: self.nodes,
255            incoming,
256            outgoing,
257            output: output_ix,
258            topo,
259            output_kinds,
260        })
261    }
262}
263
264impl Default for GraphBuilder {
265    fn default() -> Self {
266        Self::new()
267    }
268}
269
270fn topo_sort(
271    n: usize,
272    incoming: &[Vec<Option<NodeIx>>],
273    nodes: &IndexMap<NodeId, Box<dyn Node>>,
274) -> Result<Vec<NodeIx>, BuildError> {
275    // Kahn's algorithm over the unique upstream set per node.
276    let mut indegree: Vec<usize> = incoming
277        .iter()
278        .map(|ports| {
279            let mut srcs: Vec<NodeIx> = ports.iter().filter_map(|p| *p).collect();
280            srcs.sort_unstable();
281            srcs.dedup();
282            srcs.len()
283        })
284        .collect();
285
286    // Reverse adjacency: for each src, the unique dsts that depend on it.
287    let mut rev: Vec<Vec<NodeIx>> = vec![Vec::new(); n];
288    for (dst, ports) in incoming.iter().enumerate() {
289        let mut srcs: Vec<NodeIx> = ports.iter().filter_map(|p| *p).collect();
290        srcs.sort_unstable();
291        srcs.dedup();
292        for src in srcs {
293            rev[src].push(dst);
294        }
295    }
296
297    let mut queue: VecDeque<NodeIx> = (0..n).filter(|&i| indegree[i] == 0).collect();
298    let mut order = Vec::with_capacity(n);
299    while let Some(ix) = queue.pop_front() {
300        order.push(ix);
301        for &dst in &rev[ix] {
302            indegree[dst] -= 1;
303            if indegree[dst] == 0 {
304                queue.push_back(dst);
305            }
306        }
307    }
308
309    if order.len() != n {
310        // `order.len() != n` means at least one node still has incoming
311        // edges (otherwise topo would have queued it). Find one such
312        // node to name in the error.
313        let bad = (0..n)
314            .find(|&i| indegree[i] != 0)
315            .expect("order.len() != n implies some indegree is non-zero");
316        let (id, _) = nodes.get_index(bad).expect("bad < n is within nodes range");
317        return Err(BuildError::Cycle(id.clone()));
318    }
319    Ok(order)
320}
321
322impl std::fmt::Debug for Graph {
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        let ids: Vec<&str> = self.nodes.keys().map(String::as_str).collect();
325        f.debug_struct("Graph")
326            .field("nodes", &ids)
327            .field("output", &self.node_id(self.output))
328            .field(
329                "topo",
330                &self
331                    .topo
332                    .iter()
333                    .map(|&i| self.node_id(i))
334                    .collect::<Vec<_>>(),
335            )
336            .finish()
337    }
338}
339
340impl Graph {
341    /// Number of nodes.
342    pub fn len(&self) -> usize {
343        self.nodes.len()
344    }
345
346    pub fn is_empty(&self) -> bool {
347        self.nodes.is_empty()
348    }
349
350    pub fn output(&self) -> NodeIx {
351        self.output
352    }
353
354    /// Topological order; output node is last.
355    pub fn topo_order(&self) -> &[NodeIx] {
356        &self.topo
357    }
358
359    pub fn node(&self, ix: NodeIx) -> &dyn Node {
360        self.nodes
361            .get_index(ix)
362            .expect("NodeIx is always within self.nodes range")
363            .1
364            .as_ref()
365    }
366
367    pub fn node_id(&self, ix: NodeIx) -> &str {
368        self.nodes
369            .get_index(ix)
370            .expect("NodeIx is always within self.nodes range")
371            .0
372    }
373
374    /// Upstream nodes feeding `ix`, deduplicated.
375    pub fn upstream(&self, ix: NodeIx) -> impl Iterator<Item = NodeIx> + '_ {
376        let mut srcs: Vec<NodeIx> = self.incoming[ix].iter().filter_map(|p| *p).collect();
377        srcs.sort_unstable();
378        srcs.dedup();
379        srcs.into_iter()
380    }
381
382    /// Downstream nodes consuming `ix`'s output (may contain duplicates
383    /// if the same node connects multiple of its input ports to `ix`).
384    pub fn downstream(&self, ix: NodeIx) -> &[NodeIx] {
385        &self.outgoing[ix]
386    }
387
388    /// The source feeding `node.inputs()[port_ix]`, if connected.
389    pub fn incoming(&self, ix: NodeIx, port_ix: usize) -> Option<NodeIx> {
390        self.incoming[ix][port_ix]
391    }
392
393    /// Resolved output [`PortKind`] for `ix`. Decided at build time;
394    /// polymorphic nodes' kind is fixed once the graph is built.
395    pub fn output_kind(&self, ix: NodeIx) -> PortKind {
396        self.output_kinds[ix]
397    }
398
399    /// Group nodes into evaluation "levels". A node's level is one more
400    /// than the maximum level of its inputs (sources are at level 0).
401    /// All nodes in the same level have no edges between them and can
402    /// be evaluated in parallel. Returned as `levels[node_ix] = depth`.
403    pub fn compute_levels(&self) -> Vec<u32> {
404        let mut levels = vec![0u32; self.len()];
405        for &ix in &self.topo {
406            let max_up = self.upstream(ix).map(|s| levels[s] + 1).max().unwrap_or(0);
407            levels[ix] = max_up;
408        }
409        levels
410    }
411
412    /// Bucket nodes by level, preserving topo order within each bucket
413    /// for determinism.
414    pub fn level_buckets(&self) -> Vec<Vec<NodeIx>> {
415        let levels = self.compute_levels();
416        let max_level = levels.iter().copied().max().unwrap_or(0);
417        let mut buckets: Vec<Vec<NodeIx>> = vec![Vec::new(); (max_level + 1) as usize];
418        for &ix in &self.topo {
419            buckets[levels[ix] as usize].push(ix);
420        }
421        buckets
422    }
423
424    /// Compute the canvas padding each node must supply, given the
425    /// document-level `pad` requested at the output.
426    pub fn compute_pad(&self, doc_pad: u32) -> Result<Vec<u32>, BuildError> {
427        let mut required = vec![0u32; self.len()];
428        required[self.output] = doc_pad;
429        for &ix in self.topo.iter().rev() {
430            let down = required[ix];
431            let up = self.node(ix).required_pad(down);
432            if up > MAX_PAD {
433                return Err(BuildError::PadExceeded {
434                    node: self.node_id(ix).to_string(),
435                    required: up,
436                    limit: MAX_PAD,
437                });
438            }
439            for src in self.upstream(ix) {
440                required[src] = required[src].max(up);
441            }
442        }
443        Ok(required)
444    }
445}