use std::collections::VecDeque;
use indexmap::IndexMap;
use crate::node::Node;
use crate::port::PortKind;
pub type NodeId = String;
pub type NodeIx = usize;
/// Errors reported while validating a graph ([`GraphBuilder::build`]) or while computing
/// padding on a built graph ([`Graph::compute_pad`]).
#[derive(Debug, thiserror::Error)]
pub enum BuildError {
    /// An edge passed to `connect` names a node id that was never added.
    /// `from`/`to` echo the edge's source and destination ids as written.
    #[error("unknown node reference `{from}` -> `{to}`")]
    UnknownRef { from: NodeId, to: NodeId },
    /// An edge targets an input port name that the destination node does not declare.
    #[error("node `{node}` has no input port named `{port}`")]
    UnknownPort { node: NodeId, port: String },
    /// A second edge targets an input port that is already connected.
    #[error("port `{node}.{port}` already connected")]
    DuplicateEdge { node: NodeId, port: String },
    /// The kind `got` produced by node `src` was rejected by input `port` of `node`;
    /// `accepts` lists the kinds that port declares.
    #[error(
        "type mismatch on `{node}.{port}`: expected one of [{}], source `{src}` produces {got}",
        // `.accepts` is thiserror's documented syntax for referring to a field in extra
        // format arguments.
        .accepts.iter().map(|k| k.to_string()).collect::<Vec<_>>().join(", ")
    )]
    TypeMismatch {
        node: NodeId,
        port: String,
        src: NodeId,
        accepts: Vec<PortKind>,
        got: PortKind,
    },
    /// A non-optional input port has no incoming edge.
    #[error("required port `{node}.{port}` is not connected")]
    MissingInput { node: NodeId, port: String },
    /// The edges contain a cycle; the payload names a node that could not be ordered
    /// because of it.
    #[error("cycle detected involving node `{0}`")]
    Cycle(NodeId),
    /// `set_output` named an id that is not a node in the graph.
    #[error("output node `{0}` is not in the graph")]
    UnknownOutput(NodeId),
    /// `set_output` was never called.
    #[error("graph has no output node")]
    NoOutput,
    /// The designated output node does not produce `PortKind::Raster`.
    #[error(
        "output node `{node}` produces `{got}`, but the document output must produce `raster` (canvas-padded). Pipe a sprite through `place`, `tiling`, or `stamp` first."
    )]
    OutputKindMismatch { node: NodeId, got: PortKind },
    /// A node's `required_pad` result exceeds `limit` (the `MAX_PAD` constant).
    #[error("required pad ({required}) on node `{node}` exceeds limit ({limit})")]
    PadExceeded {
        node: NodeId,
        required: u32,
        limit: u32,
    },
}
/// A connection expressed with node indices: the output of node `src` feeds input port
/// `dst_port` of node `dst`.
///
/// NOTE(review): nothing in the code visible here constructs an `Edge`; `Graph` keeps its
/// connectivity in the `incoming`/`outgoing` tables instead. Confirm external users before
/// changing or removing it.
#[derive(Debug, Clone, Copy)]
pub struct Edge {
/// Index of the producing node.
pub src: NodeIx,
/// Index of the consuming node.
pub dst: NodeIx,
/// Presumably the position of the consumed port in the destination node's `inputs()`
/// (as with `Graph::incoming`) — confirm with callers.
pub dst_port: usize,
}
/// A validated node graph produced by `GraphBuilder::build`.
///
/// Nodes are addressed by `NodeIx`, their position in insertion order.
pub struct Graph {
/// Nodes keyed by id; a node's position in this map is its `NodeIx`.
nodes: IndexMap<NodeId, Box<dyn Node>>,
/// `incoming[ix][port]` is the source node wired into input `port` of node `ix`,
/// or `None` when that port is unconnected.
incoming: Vec<Vec<Option<NodeIx>>>,
/// `outgoing[ix]` holds the destination node of every edge leaving `ix`, one entry per
/// edge (a destination reached through several ports appears several times).
outgoing: Vec<Vec<NodeIx>>,
/// Index of the document output node.
output: NodeIx,
/// Every node index in topological order (sources before their consumers).
topo: Vec<NodeIx>,
/// Output kind inferred for each node during `build`, indexed by `NodeIx`.
output_kinds: Vec<PortKind>,
}
/// Largest per-node padding `Graph::compute_pad` accepts; a node whose `required_pad`
/// result exceeds it yields `BuildError::PadExceeded`.
pub const MAX_PAD: u32 = 256;
/// Collects nodes, edges and the output designation; `build` validates them into a `Graph`.
pub struct GraphBuilder {
/// Nodes keyed by id, in insertion order (which becomes each node's `NodeIx`).
nodes: IndexMap<NodeId, Box<dyn Node>>,
/// Edges as recorded by `connect`; they are not validated until `build`.
edges: Vec<EdgeSpec>,
/// Id passed to the most recent `set_output` call, if any.
output: Option<NodeId>,
}
/// An edge as written by the caller, referring to nodes by id and to the port by name.
struct EdgeSpec {
/// Id of the producing node.
src: NodeId,
/// Id of the consuming node.
dst: NodeId,
/// Name of the consuming node's input port (matched against the port's `name`).
dst_port: String,
}
impl GraphBuilder {
    /// Creates an empty builder with no nodes, no edges and no output node.
    pub fn new() -> Self {
        Self {
            nodes: IndexMap::new(),
            edges: Vec::new(),
            output: None,
        }
    }

    /// Adds `node` under `id`.
    ///
    /// Insertion order determines each node's [`NodeIx`]. Re-using an existing `id` replaces
    /// the earlier node in place (it keeps the original index) instead of reporting an error.
    pub fn add_node(&mut self, id: impl Into<NodeId>, node: Box<dyn Node>) -> &mut Self {
        self.nodes.insert(id.into(), node);
        self
    }

    /// Records an edge from the output of `src` into input port `dst_port` of `dst`.
    ///
    /// Nothing is checked here. Unknown ids, unknown ports and double connections are
    /// reported by [`GraphBuilder::build`].
    pub fn connect(
        &mut self,
        src: impl Into<NodeId>,
        dst: impl Into<NodeId>,
        dst_port: impl Into<String>,
    ) -> &mut Self {
        self.edges.push(EdgeSpec {
            src: src.into(),
            dst: dst.into(),
            dst_port: dst_port.into(),
        });
        self
    }

    /// Designates `id` as the document output node; a later call overrides an earlier one.
    pub fn set_output(&mut self, id: impl Into<NodeId>) -> &mut Self {
        self.output = Some(id.into());
        self
    }

    /// Validates the recorded nodes and edges and produces a [`Graph`].
    ///
    /// Checks run in this order and the first failure is returned:
    /// 1. every edge (in `connect` order) names known nodes and a declared input port, and
    ///    no port is connected twice;
    /// 2. every non-optional input port is connected;
    /// 3. the edges are acyclic;
    /// 4. each connected source's output kind is accepted by the port it feeds (kinds are
    ///    inferred in topological order via [`Node::output`]);
    /// 5. an output node was set, exists, and produces [`PortKind::Raster`].
    ///
    /// # Errors
    /// [`BuildError::UnknownRef`], [`BuildError::UnknownPort`], [`BuildError::DuplicateEdge`],
    /// [`BuildError::MissingInput`], [`BuildError::Cycle`], [`BuildError::TypeMismatch`],
    /// [`BuildError::NoOutput`], [`BuildError::UnknownOutput`] or
    /// [`BuildError::OutputKindMismatch`] for the corresponding failure above.
    pub fn build(self) -> Result<Graph, BuildError> {
        let n = self.nodes.len();

        // incoming[dst][port] = source feeding that port; outgoing[src] gets one entry per edge.
        let mut incoming: Vec<Vec<Option<NodeIx>>> = self
            .nodes
            .values()
            .map(|node| vec![None; node.inputs().len()])
            .collect();
        let mut outgoing: Vec<Vec<NodeIx>> = vec![Vec::new(); n];
        for edge in &self.edges {
            // Captures only shared references, so the closure is `Copy` and can be reused.
            let unknown_ref = || BuildError::UnknownRef {
                from: edge.src.clone(),
                to: edge.dst.clone(),
            };
            let src_ix = self
                .nodes
                .get_index_of(edge.src.as_str())
                .ok_or_else(unknown_ref)?;
            let dst_ix = self
                .nodes
                .get_index_of(edge.dst.as_str())
                .ok_or_else(unknown_ref)?;
            let (_, dst_node) = self
                .nodes
                .get_index(dst_ix)
                .expect("dst_ix came from get_index_of and is in range");
            let port_ix = dst_node
                .inputs()
                .iter()
                .position(|p| p.name == edge.dst_port)
                .ok_or_else(|| BuildError::UnknownPort {
                    node: edge.dst.clone(),
                    port: edge.dst_port.clone(),
                })?;
            if incoming[dst_ix][port_ix].is_some() {
                return Err(BuildError::DuplicateEdge {
                    node: edge.dst.clone(),
                    port: edge.dst_port.clone(),
                });
            }
            incoming[dst_ix][port_ix] = Some(src_ix);
            outgoing[src_ix].push(dst_ix);
        }

        self.check_required_inputs(&incoming)?;
        let topo = topo_sort(n, &incoming, &self.nodes)?;
        let output_kinds = self.infer_output_kinds(&topo, &incoming)?;

        // The output is resolved last so that structural errors above take precedence.
        // Moving `self.output` out is fine: only `&self` borrows happened before this point.
        let output_id = self.output.ok_or(BuildError::NoOutput)?;
        let output_ix = self
            .nodes
            .get_index_of(output_id.as_str())
            // Lazily build the error so the id is only cloned on the failure path.
            .ok_or_else(|| BuildError::UnknownOutput(output_id.clone()))?;
        let output_kind = output_kinds[output_ix];
        if output_kind != PortKind::Raster {
            return Err(BuildError::OutputKindMismatch {
                node: output_id,
                got: output_kind,
            });
        }

        Ok(Graph {
            nodes: self.nodes,
            incoming,
            outgoing,
            output: output_ix,
            topo,
            output_kinds,
        })
    }

    /// Fails with [`BuildError::MissingInput`] for the first non-optional input port (in
    /// node order, then port order) that has no incoming edge.
    fn check_required_inputs(&self, incoming: &[Vec<Option<NodeIx>>]) -> Result<(), BuildError> {
        for (ix, (id, node)) in self.nodes.iter().enumerate() {
            for (port_ix, port) in node.inputs().iter().enumerate() {
                if !port.optional && incoming[ix][port_ix].is_none() {
                    return Err(BuildError::MissingInput {
                        node: id.clone(),
                        port: port.name.to_string(),
                    });
                }
            }
        }
        Ok(())
    }

    /// Infers every node's output kind by visiting `topo` in order. Each connected input
    /// must receive a kind its port accepts.
    ///
    /// # Errors
    /// [`BuildError::TypeMismatch`] for the first rejected connection.
    fn infer_output_kinds(
        &self,
        topo: &[NodeIx],
        incoming: &[Vec<Option<NodeIx>>],
    ) -> Result<Vec<PortKind>, BuildError> {
        // Placeholder values. `topo` contains every node and lists sources before their
        // consumers, so each entry is overwritten before anything reads it.
        let mut output_kinds: Vec<PortKind> = vec![PortKind::Raster; self.nodes.len()];
        for &ix in topo {
            let (id, node) = self.nodes.get_index(ix).expect("ix from topo is in range");
            let specs = node.inputs();
            let mut input_kinds: Vec<Option<PortKind>> = Vec::with_capacity(specs.len());
            for (port_ix, spec) in specs.iter().enumerate() {
                let kind = match incoming[ix][port_ix] {
                    Some(src_ix) => {
                        let src_kind = output_kinds[src_ix];
                        if !spec.accepts_kind(src_kind) {
                            let (src_id, _) = self
                                .nodes
                                .get_index(src_ix)
                                .expect("src_ix from incoming is in range");
                            return Err(BuildError::TypeMismatch {
                                node: id.clone(),
                                port: spec.name.to_string(),
                                src: src_id.clone(),
                                accepts: spec.accepts.to_vec(),
                                got: src_kind,
                            });
                        }
                        Some(src_kind)
                    }
                    None => None,
                };
                input_kinds.push(kind);
            }
            output_kinds[ix] = node.output(&input_kinds);
        }
        Ok(output_kinds)
    }
}
impl Default for GraphBuilder {
fn default() -> Self {
Self::new()
}
}
/// Orders node indices so that each node comes after every node feeding one of its input
/// ports (Kahn's algorithm).
///
/// `incoming[dst][port]` is the source wired into input `port` of node `dst`. A source that
/// feeds several ports of the same node counts as a single dependency. Ready nodes are
/// processed first-in first-out, seeded in ascending index order, so the result is
/// deterministic.
///
/// # Errors
/// Returns [`BuildError::Cycle`] when no complete ordering exists. The reported node lies on
/// a cycle itself, not merely downstream of one.
fn topo_sort(
    n: usize,
    incoming: &[Vec<Option<NodeIx>>],
    nodes: &IndexMap<NodeId, Box<dyn Node>>,
) -> Result<Vec<NodeIx>, BuildError> {
    // Distinct sources of every node: several ports wired to one source collapse into one.
    let sources: Vec<Vec<NodeIx>> = incoming
        .iter()
        .map(|ports| {
            let mut srcs: Vec<NodeIx> = ports.iter().flatten().copied().collect();
            srcs.sort_unstable();
            srcs.dedup();
            srcs
        })
        .collect();
    // indegree[i] = number of distinct sources of `i` that are not yet in `order`.
    let mut indegree: Vec<usize> = sources.iter().map(Vec::len).collect();
    let mut dependents: Vec<Vec<NodeIx>> = vec![Vec::new(); n];
    for (dst, srcs) in sources.iter().enumerate() {
        for &src in srcs {
            dependents[src].push(dst);
        }
    }

    let mut queue: VecDeque<NodeIx> = (0..n).filter(|&i| indegree[i] == 0).collect();
    let mut order = Vec::with_capacity(n);
    while let Some(ix) = queue.pop_front() {
        order.push(ix);
        for &dst in &dependents[ix] {
            indegree[dst] -= 1;
            if indegree[dst] == 0 {
                queue.push_back(dst);
            }
        }
    }

    if order.len() != n {
        let on_cycle = find_cycle_node(&sources, &indegree);
        let (id, _) = nodes
            .get_index(on_cycle)
            .expect("cycle node index is within nodes range");
        return Err(BuildError::Cycle(id.clone()));
    }
    Ok(order)
}

/// Returns a node that lies on a cycle among the nodes `topo_sort` could not order.
///
/// After the sort, `indegree[i]` counts the unordered distinct sources of `i`. Every node
/// that reached 0 was queued and ordered, so a node is unordered exactly when its indegree is
/// non-zero, and every unordered node has at least one unordered source. Walking backwards
/// through unordered sources must therefore revisit a node, and the first revisited node is
/// on a cycle. (An arbitrary leftover node may only sit downstream of a cycle.)
///
/// # Panics
/// If every `indegree` entry is zero, i.e. the sort actually completed.
fn find_cycle_node(sources: &[Vec<NodeIx>], indegree: &[usize]) -> NodeIx {
    let mut cur = indegree
        .iter()
        .position(|&d| d != 0)
        .expect("an incomplete topological order leaves some indegree non-zero");
    let mut seen = vec![false; indegree.len()];
    while !seen[cur] {
        seen[cur] = true;
        cur = sources[cur]
            .iter()
            .copied()
            .find(|&src| indegree[src] != 0)
            .expect("an unordered node always has an unordered source");
    }
    cur
}
impl std::fmt::Debug for Graph {
    /// Renders the graph by node id only: every id in insertion order, the output node's id,
    /// and the ids in topological order. Node internals are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let node_ids = self.nodes.keys().map(String::as_str).collect::<Vec<&str>>();
        let output_id = self.node_id(self.output);
        let topo_ids = self
            .topo
            .iter()
            .map(|&ix| self.node_id(ix))
            .collect::<Vec<&str>>();

        let mut out = f.debug_struct("Graph");
        out.field("nodes", &node_ids);
        out.field("output", &output_id);
        out.field("topo", &topo_ids);
        out.finish()
    }
}
impl Graph {
    /// Number of nodes in the graph.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// `true` when the graph holds no nodes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Index of the node designated as the document output.
    pub fn output(&self) -> NodeIx {
        self.output
    }

    /// All node indices in topological order: every node follows its sources.
    pub fn topo_order(&self) -> &[NodeIx] {
        self.topo.as_slice()
    }

    /// The node stored at `ix`.
    ///
    /// # Panics
    /// If `ix` is not a valid index for this graph.
    pub fn node(&self, ix: NodeIx) -> &dyn Node {
        let (_, boxed) = self
            .nodes
            .get_index(ix)
            .expect("NodeIx is always within self.nodes range");
        &**boxed
    }

    /// The id under which the node at `ix` was added.
    ///
    /// # Panics
    /// If `ix` is not a valid index for this graph.
    pub fn node_id(&self, ix: NodeIx) -> &str {
        let (id, _) = self
            .nodes
            .get_index(ix)
            .expect("NodeIx is always within self.nodes range");
        id.as_str()
    }

    /// Distinct source nodes feeding any input port of `ix`, in ascending index order.
    pub fn upstream(&self, ix: NodeIx) -> impl Iterator<Item = NodeIx> + '_ {
        let mut sources: Vec<NodeIx> = self.incoming[ix].iter().flatten().copied().collect();
        sources.sort_unstable();
        sources.dedup();
        sources.into_iter()
    }

    /// Consumers of `ix`'s output as stored at build time: one entry per outgoing edge.
    /// A consumer wired through several ports therefore repeats, unlike [`Graph::upstream`],
    /// which deduplicates.
    pub fn downstream(&self, ix: NodeIx) -> &[NodeIx] {
        self.outgoing[ix].as_slice()
    }

    /// Source node wired into input `port_ix` of node `ix`, or `None` if that port is
    /// unconnected.
    ///
    /// # Panics
    /// If `ix` or `port_ix` is out of range.
    pub fn incoming(&self, ix: NodeIx, port_ix: usize) -> Option<NodeIx> {
        self.incoming[ix][port_ix]
    }

    /// Output kind inferred for node `ix` when the graph was built.
    pub fn output_kind(&self, ix: NodeIx) -> PortKind {
        self.output_kinds[ix]
    }

    /// Depth of every node, indexed by [`NodeIx`].
    ///
    /// A node with no connected inputs has depth 0; any other node sits one level below its
    /// deepest source. Consequently a node never shares a level with one of its direct sources.
    pub fn compute_levels(&self) -> Vec<u32> {
        let mut levels = vec![0u32; self.len()];
        for &ix in &self.topo {
            // Sources precede `ix` in topological order, so their levels are already final.
            let depth = self
                .upstream(ix)
                .fold(0u32, |deepest, src| deepest.max(levels[src] + 1));
            levels[ix] = depth;
        }
        levels
    }

    /// Groups node indices by [`Graph::compute_levels`]. Bucket `k` holds the level-`k`
    /// nodes in topological order; at least one (possibly empty) bucket is always returned.
    pub fn level_buckets(&self) -> Vec<Vec<NodeIx>> {
        let levels = self.compute_levels();
        let deepest = levels.iter().copied().max().unwrap_or(0);
        let mut buckets: Vec<Vec<NodeIx>> = vec![Vec::new(); (deepest + 1) as usize];
        self.topo
            .iter()
            .for_each(|&ix| buckets[levels[ix] as usize].push(ix));
        buckets
    }

    /// Propagates padding requirements backwards from the output node.
    ///
    /// The output node starts at `doc_pad` and every other node starts at 0. Nodes are
    /// visited in reverse topological order. Each node's [`Node::required_pad`] result for
    /// its accumulated requirement is raised into every upstream source, keeping the maximum
    /// over all consumers. Returns the accumulated requirement per node, indexed by
    /// [`NodeIx`].
    ///
    /// NOTE(review): only `required_pad` results are compared with [`MAX_PAD`]; `doc_pad`
    /// itself is stored for the output node unchecked. Confirm callers bound it.
    ///
    /// # Errors
    /// [`BuildError::PadExceeded`] for the first visited node whose `required_pad` result
    /// exceeds [`MAX_PAD`].
    pub fn compute_pad(&self, doc_pad: u32) -> Result<Vec<u32>, BuildError> {
        let mut pads = vec![0u32; self.len()];
        pads[self.output] = doc_pad;
        for &ix in self.topo.iter().rev() {
            let needed = self.node(ix).required_pad(pads[ix]);
            if needed > MAX_PAD {
                return Err(BuildError::PadExceeded {
                    node: self.node_id(ix).to_owned(),
                    required: needed,
                    limit: MAX_PAD,
                });
            }
            for src in self.upstream(ix) {
                if pads[src] < needed {
                    pads[src] = needed;
                }
            }
        }
        Ok(pads)
    }
}