use std::collections::HashMap;
use std::collections::HashSet;
use anyhow::Result;
use crate::sdf::Path;
use crate::usd::{PrimPredicate, Stage};
/// Bidirectional index of the attribute connections found on a [`Stage`].
///
/// For every attribute whose `connections()` list is non-empty, the graph records the
/// attribute's *sources* (the paths it is connected to) and, for each source, its
/// *sinks* (the attributes connected to it).
#[derive(Debug, Clone, Default)]
pub struct ConnectionGraph {
    /// Attribute -> its sources, in the order `connections()` returned them.
    /// Only attributes with at least one source have an entry.
    forward: HashMap<Path, Vec<Path>>,
    /// Source path -> the attributes connected to it, in the order they were indexed.
    reverse: HashMap<Path, Vec<Path>>,
}

impl ConnectionGraph {
    /// Builds the graph by traversing `stage` with [`PrimPredicate::DEFAULT_PROXIES`]
    /// and indexing the connections of every visited prim.
    ///
    /// # Errors
    ///
    /// Returns the error from `Stage::traverse` if the traversal itself fails;
    /// otherwise returns the first error raised while indexing a prim. After such an
    /// error, the remaining visited prims are skipped.
    pub fn from_stage(stage: &Stage) -> Result<Self> {
        let mut graph = ConnectionGraph::default();
        // The traversal callback returns `()`, so the first failure is parked here
        // and surfaced once the traversal has finished.
        let mut first_err: Result<()> = Ok(());
        stage.traverse(PrimPredicate::DEFAULT_PROXIES, |prim| {
            if first_err.is_err() {
                return;
            }
            if let Err(e) = graph.index_prim(stage, prim) {
                first_err = Err(e);
            }
        })?;
        first_err?;
        Ok(graph)
    }

    /// Indexes every property of `prim` whose `connections()` list is non-empty.
    ///
    /// Each such property is added to `forward` and registered as a sink of each of
    /// its sources in `reverse`.
    ///
    /// NOTE(review): every name from `property_names()` is looked up through
    /// `attribute_at`. If that list can include non-attribute properties (e.g.
    /// relationships), confirm `connections()` handles them rather than erroring,
    /// since any error here aborts `from_stage`.
    fn index_prim(&mut self, stage: &Stage, prim: &Path) -> Result<()> {
        for prop in stage.prim_at(prim.clone()).property_names()? {
            let attr = prim.append_property(&prop)?;
            let sources = stage.attribute_at(attr.clone()).connections()?;
            // Unconnected attributes get no entry, so `forward` never holds an empty list.
            if sources.is_empty() {
                continue;
            }
            for source in &sources {
                self.reverse
                    .entry(source.clone())
                    .or_default()
                    .push(attr.clone());
            }
            self.forward.insert(attr, sources);
        }
        Ok(())
    }

    /// Paths that `attr` is connected to, or an empty slice if it has none.
    pub fn sources(&self, attr: &Path) -> &[Path] {
        self.forward.get(attr).map_or(&[], Vec::as_slice)
    }

    /// Attributes connected to `attr`, or an empty slice if nothing is connected to it.
    pub fn sinks(&self, attr: &Path) -> &[Path] {
        self.reverse.get(attr).map_or(&[], Vec::as_slice)
    }

    /// Returns `true` if `attr` has at least one source.
    ///
    /// Only the attribute's own connections are considered. A path that is merely
    /// the source of other attributes' connections (it has sinks but no sources)
    /// returns `false`.
    pub fn is_connected(&self, attr: &Path) -> bool {
        self.forward.contains_key(attr)
    }

    /// Number of attributes with at least one source.
    ///
    /// This counts attributes, not edges: an attribute connected to several sources
    /// counts once. Use [`edges`](Self::edges) to count individual connections.
    pub fn len(&self) -> usize {
        self.forward.len()
    }

    /// Returns `true` if no attribute in the graph has a source.
    pub fn is_empty(&self) -> bool {
        self.forward.is_empty()
    }

    /// Iterates over every connection as an `(attribute, source)` pair.
    ///
    /// The first element is the attribute that holds the connection, and the second
    /// is the path it is connected to. Iteration follows `HashMap` order, so the
    /// order of pairs is unspecified.
    pub fn edges(&self) -> impl Iterator<Item = (&Path, &Path)> {
        self.forward
            .iter()
            .flat_map(|(attr, sources)| sources.iter().map(move |source| (attr, source)))
    }

    /// Follows connections from `attr` depth-first and returns every reachable path
    /// that has no sources of its own (the chain's terminals).
    ///
    /// - If `attr` has no sources, it is its own terminal and the result is `[attr]`.
    /// - Each path is visited at most once. A terminal reached through several
    ///   branches is therefore reported once, and cycles terminate. A chain that only
    ///   loops back on itself yields no terminals.
    /// - Terminals appear in depth-first order, following each attribute's sources in
    ///   their recorded order.
    /// - An explicit stack is used instead of recursion, so long chains cannot
    ///   overflow the call stack.
    pub fn resolve_chain(&self, attr: &Path) -> Vec<Path> {
        let mut terminals = Vec::new();
        // Borrow instead of cloning: every node comes from `self` or `attr`, and both
        // outlive this call.
        let mut visited: HashSet<&Path> = HashSet::new();
        let mut stack: Vec<&Path> = vec![attr];
        while let Some(node) = stack.pop() {
            if !visited.insert(node) {
                continue;
            }
            match self.forward.get(node) {
                None => terminals.push(node.clone()),
                // Pushed in reverse so the sources are popped (explored) in recorded order.
                Some(sources) => stack.extend(sources.iter().rev()),
            }
        }
        terminals
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sdf;

    /// Builds a two-hop chain: `/G/A.inputs:in` is connected to `/G/B.outputs:out`,
    /// and `/G/B.inputs:in` is connected to `/G/C.outputs:out`.
    fn chain_stage() -> Result<Stage> {
        let stage = Stage::builder().in_memory("anon.usda")?;
        stage.define_prim("/G")?.set_type_name("Scope")?;
        let c = stage.define_prim("/G/C")?.set_type_name("Shader")?;
        c.create_attribute("outputs:out", "float")?;
        let b = stage.define_prim("/G/B")?.set_type_name("Shader")?;
        b.create_attribute("outputs:out", "float")?;
        b.create_attribute("inputs:in", "float")?
            .set_connections([sdf::path("/G/C.outputs:out")?])?;
        let a = stage.define_prim("/G/A")?.set_type_name("Shader")?;
        a.create_attribute("inputs:in", "float")?
            .set_connections([sdf::path("/G/B.outputs:out")?])?;
        Ok(stage)
    }

    #[test]
    fn sources_and_sinks() -> Result<()> {
        let stage = chain_stage()?;
        let graph = ConnectionGraph::from_stage(&stage)?;
        assert_eq!(
            graph.sources(&sdf::path("/G/A.inputs:in")?),
            &[sdf::path("/G/B.outputs:out")?]
        );
        assert!(graph.sinks(&sdf::path("/G/A.inputs:in")?).is_empty());
        assert_eq!(
            graph.sinks(&sdf::path("/G/B.outputs:out")?),
            &[sdf::path("/G/A.inputs:in")?]
        );
        assert!(graph.is_connected(&sdf::path("/G/A.inputs:in")?));
        assert!(!graph.is_connected(&sdf::path("/G/C.outputs:out")?));
        assert_eq!(graph.len(), 2);
        Ok(())
    }

    #[test]
    fn resolve_chain_to_terminal() -> Result<()> {
        let stage = chain_stage()?;
        let graph = ConnectionGraph::from_stage(&stage)?;
        let terminals = graph.resolve_chain(&sdf::path("/G/A.inputs:in")?);
        assert_eq!(terminals, vec![sdf::path("/G/B.outputs:out")?]);
        Ok(())
    }

    #[test]
    fn resolve_chain_of_unconnected_attr_is_itself() -> Result<()> {
        let stage = chain_stage()?;
        let graph = ConnectionGraph::from_stage(&stage)?;
        let leaf = sdf::path("/G/C.outputs:out")?;
        assert_eq!(graph.resolve_chain(&leaf), vec![leaf.clone()]);
        Ok(())
    }

    #[test]
    fn resolve_chain_deep() -> Result<()> {
        let stage = Stage::builder().in_memory("anon.usda")?;
        let host = stage.define_prim("/Host")?.set_type_name("Shader")?;
        const N: usize = 2_000;
        for i in 0..N - 1 {
            host.create_attribute(&format!("inputs:n{i}"), "float")?
                .set_connections([sdf::path(format!("/Host.inputs:n{}", i + 1))?])?;
        }
        host.create_attribute(&format!("inputs:n{}", N - 1), "float")?;
        let graph = ConnectionGraph::from_stage(&stage)?;
        let terminals = graph.resolve_chain(&sdf::path("/Host.inputs:n0")?);
        assert_eq!(terminals, vec![sdf::path(format!("/Host.inputs:n{}", N - 1))?]);
        Ok(())
    }

    #[test]
    fn shared_terminal_is_reported_once() -> Result<()> {
        let stage = Stage::builder().in_memory("anon.usda")?;
        let p = stage.define_prim("/D")?.set_type_name("Shader")?;
        p.create_attribute("inputs:leaf", "float")?;
        p.create_attribute("inputs:left", "float")?
            .set_connections([sdf::path("/D.inputs:leaf")?])?;
        p.create_attribute("inputs:right", "float")?
            .set_connections([sdf::path("/D.inputs:leaf")?])?;
        p.create_attribute("inputs:root", "float")?.set_connections([
            sdf::path("/D.inputs:left")?,
            sdf::path("/D.inputs:right")?,
        ])?;
        let graph = ConnectionGraph::from_stage(&stage)?;
        let leaf = sdf::path("/D.inputs:leaf")?;
        assert_eq!(
            graph.resolve_chain(&sdf::path("/D.inputs:root")?),
            vec![leaf.clone()]
        );
        assert_eq!(graph.sinks(&leaf).len(), 2);
        assert_eq!(graph.len(), 3);
        assert_eq!(graph.edges().count(), 4);
        Ok(())
    }

    #[test]
    fn cycle_is_broken() -> Result<()> {
        let stage = Stage::builder().in_memory("anon.usda")?;
        let p = stage.define_prim("/P")?.set_type_name("Shader")?;
        p.create_attribute("inputs:a", "float")?
            .set_connections([sdf::path("/P.inputs:b")?])?;
        p.create_attribute("inputs:b", "float")?
            .set_connections([sdf::path("/P.inputs:a")?])?;
        let graph = ConnectionGraph::from_stage(&stage)?;
        let terminals = graph.resolve_chain(&sdf::path("/P.inputs:a")?);
        assert!(terminals.is_empty());
        Ok(())
    }

    #[test]
    fn stage_without_connections_has_no_edges() -> Result<()> {
        let stage = Stage::builder().in_memory("anon.usda")?;
        stage.define_prim("/X")?.set_type_name("Scope")?;
        let graph = ConnectionGraph::from_stage(&stage)?;
        assert!(graph.is_empty());
        assert_eq!(graph.edges().count(), 0);
        Ok(())
    }
}