use std::collections::{BTreeMap, BTreeSet};
use serde::{Deserialize, Serialize};
use crate::error::{EngawaError, ValidationError};
use crate::node::{Node, NodeId};
use crate::resource::{ResourceId, ResourceKind};
/// A render graph under construction: declared resources, graph-level inputs and outputs,
/// and the nodes that read and write those resources.
///
/// Built up with the `with_*` builder methods and validated/ordered by `compile`, which
/// consumes it and produces a `CompiledGraph`. Derives serde `Serialize`/`Deserialize`.
#[derive(Debug, Clone, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct RenderGraph {
/// Declared resources keyed by id. `compile` carries these through unchanged and does not
/// cross-check them against node inputs/outputs.
pub resources: BTreeMap<ResourceId, ResourceKind>,
/// Resources bound from outside the graph; `compile` lets a node read one of these even when
/// no node produces it.
pub inputs: BTreeSet<ResourceId>,
/// Resources the graph exposes; `compile` requires each to be a node output or a graph input.
pub outputs: BTreeSet<ResourceId>,
/// Nodes in insertion order; `compile` reorders them into execution order.
pub nodes: Vec<Node>,
}
impl RenderGraph {
#[must_use]
pub fn with_resource(mut self, id: impl Into<ResourceId>, kind: ResourceKind) -> Self {
self.resources.insert(id.into(), kind);
self
}
#[must_use]
pub fn with_input(mut self, id: impl Into<ResourceId>) -> Self {
self.inputs.insert(id.into());
self
}
#[must_use]
pub fn with_output(mut self, id: impl Into<ResourceId>) -> Self {
self.outputs.insert(id.into());
self
}
#[must_use]
pub fn with_node(mut self, node: Node) -> Self {
self.nodes.push(node);
self
}
pub fn compile(self) -> Result<CompiledGraph, EngawaError> {
let mut seen_nodes: BTreeSet<NodeId> = BTreeSet::new();
for n in &self.nodes {
if !seen_nodes.insert(n.id.clone()) {
return Err(ValidationError::DuplicateNode(n.id.clone()).into());
}
}
let mut producer: BTreeMap<ResourceId, NodeId> = BTreeMap::new();
for n in &self.nodes {
for out in &n.outputs {
if let Some(prior) = producer.insert(out.clone(), n.id.clone())
&& prior != n.id
{
return Err(ValidationError::MultipleWriters(out.clone()).into());
}
}
}
for n in &self.nodes {
for input in &n.inputs {
if !self.inputs.contains(input) && !producer.contains_key(input) {
return Err(ValidationError::UnboundInput {
node: n.id.clone(),
resource: input.clone(),
}
.into());
}
}
}
for out in &self.outputs {
if !producer.contains_key(out) && !self.inputs.contains(out) {
return Err(ValidationError::UnboundOutput(out.clone()).into());
}
}
let nodes_by_id: BTreeMap<NodeId, Node> =
self.nodes.iter().map(|n| (n.id.clone(), n.clone())).collect();
let mut indeg: BTreeMap<NodeId, usize> = BTreeMap::new();
for n in &self.nodes {
let mut count = 0usize;
for input in &n.inputs {
if let Some(producer_id) = producer.get(input)
&& *producer_id != n.id
{
count += 1;
}
}
indeg.insert(n.id.clone(), count);
}
let mut ready: BTreeSet<NodeId> = indeg
.iter()
.filter(|&(_, d)| *d == 0)
.map(|(id, _)| id.clone())
.collect();
let mut order: Vec<NodeId> = Vec::with_capacity(self.nodes.len());
while let Some(next_id) = ready.iter().next().cloned() {
ready.remove(&next_id);
order.push(next_id.clone());
let node = &nodes_by_id[&next_id];
for out in &node.outputs {
for other in &self.nodes {
if other.id == next_id {
continue;
}
if other.inputs.contains(out)
&& let Some(d) = indeg.get_mut(&other.id)
{
if *d > 0 {
*d -= 1;
}
if *d == 0 {
ready.insert(other.id.clone());
}
}
}
}
}
if order.len() != self.nodes.len() {
let stuck: Vec<NodeId> = indeg
.into_iter()
.filter(|(id, d)| *d > 0 && !order.contains(id))
.map(|(id, _)| id)
.collect();
return Err(ValidationError::Cycle(stuck).into());
}
Ok(CompiledGraph {
resources: self.resources,
inputs: self.inputs,
outputs: self.outputs,
execution_order: order
.into_iter()
.map(|id| nodes_by_id[&id].clone())
.collect(),
})
}
}
/// A render graph after `RenderGraph::compile`: the same resources, inputs and outputs, with
/// the nodes stored in execution order.
///
/// NOTE(review): all fields are `pub`, so the ordering invariant below only holds for values
/// produced by `RenderGraph::compile`; the type itself does not enforce it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompiledGraph {
/// Resource declarations carried over unchanged from the source graph.
pub resources: BTreeMap<ResourceId, ResourceKind>,
/// Graph-level inputs carried over unchanged from the source graph.
pub inputs: BTreeSet<ResourceId>,
/// Graph-level outputs carried over unchanged from the source graph.
pub outputs: BTreeSet<ResourceId>,
/// All nodes, ordered (by `compile`) so each runs after the nodes that produce its inputs.
pub execution_order: Vec<Node>,
}
impl CompiledGraph {
    /// Number of nodes in the compiled graph.
    #[must_use]
    pub fn node_count(&self) -> usize {
        self.execution_order.len()
    }

    /// Iterates over the nodes in execution order.
    ///
    /// The iterator knows its exact length and can be reversed (e.g. to walk the graph from
    /// its last node back to its first) without collecting first.
    pub fn iter_nodes(
        &self,
    ) -> impl Iterator<Item = &Node> + DoubleEndedIterator + ExactSizeIterator {
        self.execution_order.iter()
    }
}