use crate::kvasir::node::ExecutionContext;
use crate::kvasir::resource::ResourceId;
use std::collections::HashMap;
/// Owned, type-erased pass node, as returned by [`PassRegistration::constructor`].
pub type DynKvasirNode = Box<dyn ErasedKvasirNode + 'static>;
/// Object-safe node interface, so heterogeneous nodes can be stored behind a
/// [`DynKvasirNode`] box.
pub trait ErasedKvasirNode {
    /// Static label string for this node.
    fn label(&self) -> &'static str;
    /// Identifier of the pass this node belongs to.
    fn pass_id(&self) -> super::nodes::PassId;
    /// Resources this node declares as its inputs.
    fn inputs(&self) -> &[ResourceId];
    /// Resources this node declares as its outputs.
    fn outputs(&self) -> &[ResourceId];
    /// Runs the node with mutable access to the execution context.
    fn execute(&self, context: &mut ExecutionContext);
}
/// Static description of one pass: its identity, declared resources, ordering
/// constraints, and a factory for the node that implements it.
// `Debug`/`Clone` are derivable because every field is a `&'static` reference or a
// `fn` pointer; public types should be debuggable (Rust API guidelines, C-COMMON-TRAITS).
#[derive(Clone, Debug)]
pub struct PassRegistration {
    /// Unique key for the pass; [`PassRegistry::register`] panics on a duplicate.
    pub id: &'static str,
    /// Label for the pass (not read by anything in this module).
    pub label: &'static str,
    /// Names of resources the pass reads (not read by anything in this module).
    pub inputs: &'static [&'static str],
    /// Names of resources the pass writes (not read by anything in this module).
    pub outputs: &'static [&'static str],
    /// Ids of passes that must precede this one in [`PassRegistry::topo_sort`].
    /// Every id listed here must itself be registered, or sorting fails.
    pub after: &'static [&'static str],
    /// Factory returning a new boxed node for this pass.
    pub constructor: fn() -> DynKvasirNode,
}
/// Collection of [`PassRegistration`]s, kept in registration order with unique ids.
#[derive(Clone, Debug)]
pub struct PassRegistry {
    passes: Vec<PassRegistration>,
}
impl Default for PassRegistry {
fn default() -> Self {
Self::new()
}
}
impl PassRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self { passes: Vec::new() }
    }
    /// Appends `pass` to the registry; registration order is preserved and is what
    /// [`PassRegistry::iter`] yields and [`PassRegistry::topo_sort`] seeds its queue with.
    ///
    /// # Panics
    ///
    /// Panics if a pass with the same `id` is already registered.
    pub fn register(&mut self, pass: PassRegistration) {
        if self.passes.iter().any(|p| p.id == pass.id) {
            panic!("PassRegistry: duplicate pass id `{}`", pass.id);
        }
        self.passes.push(pass);
    }
    /// Number of registered passes.
    pub fn len(&self) -> usize {
        self.passes.len()
    }
    /// Whether no passes have been registered.
    pub fn is_empty(&self) -> bool {
        self.passes.is_empty()
    }
    /// Iterates over the registrations in registration order.
    pub fn iter(&self) -> impl Iterator<Item = &PassRegistration> {
        self.passes.iter()
    }
    /// Looks up a registration by id (linear scan).
    pub fn get(&self, id: &str) -> Option<&PassRegistration> {
        self.passes.iter().find(|p| p.id == id)
    }
    /// Orders the registered pass ids so that every pass appears after all passes
    /// listed in its `after` field.
    ///
    /// Uses Kahn's algorithm with a FIFO queue. The queue starts with every pass
    /// that has no `after` entries, in registration order. A pass is appended to the
    /// queue as soon as its last dependency has been emitted, with successors visited
    /// in registration order. With no `after` constraints the result is exactly the
    /// registration order. Runs in O(passes + dependency edges).
    ///
    /// # Errors
    ///
    /// Returns an error if an `after` entry names an id that was never registered.
    /// Also returns an error if the dependencies contain a cycle; a pass listing
    /// itself counts as a cycle. The cycle error starts with
    /// `"PassRegistry: cycle detected"`. It then lists every pass that could not be
    /// ordered: the cycle members plus anything downstream of them.
    pub fn topo_sort(&self) -> Result<Vec<&'static str>, String> {
        let count = self.passes.len();
        // Dense registration index per id, so the graph can live in plain Vecs.
        let index_of: HashMap<&'static str, usize> = self
            .passes
            .iter()
            .enumerate()
            .map(|(index, pass)| (pass.id, index))
            .collect();
        // `dependents[d]` lists the passes that must come after pass `d`. It is filled
        // while walking passes in registration order, so each list is already in
        // registration order and needs no sorting.
        let mut dependents: Vec<Vec<usize>> = vec![Vec::new(); count];
        let mut indegree = vec![0usize; count];
        for (index, pass) in self.passes.iter().enumerate() {
            for &after in pass.after {
                let dependency = match index_of.get(after) {
                    Some(&dependency) => dependency,
                    None => {
                        return Err(format!(
                            "PassRegistry: unknown dependency `{}` referenced by `{}`",
                            after, pass.id
                        ));
                    }
                };
                // Duplicate `after` entries add duplicate edges; they are decremented
                // the same number of times below, so they stay consistent.
                dependents[dependency].push(index);
                indegree[index] += 1;
            }
        }
        let mut ready: std::collections::VecDeque<usize> =
            (0..count).filter(|&index| indegree[index] == 0).collect();
        let mut order = Vec::with_capacity(count);
        while let Some(index) = ready.pop_front() {
            order.push(self.passes[index].id);
            for &next in &dependents[index] {
                indegree[next] -= 1;
                if indegree[next] == 0 {
                    ready.push_back(next);
                }
            }
        }
        if order.len() == count {
            return Ok(order);
        }
        // Every pass that reached in-degree 0 was emitted, so the ones still holding
        // unmet dependencies are exactly those on, or downstream of, a cycle.
        let unresolved: Vec<String> = self
            .passes
            .iter()
            .zip(&indegree)
            .filter(|&(_, &remaining)| remaining > 0)
            .map(|(pass, _)| format!("`{}`", pass.id))
            .collect();
        Err(format!(
            "PassRegistry: cycle detected (unresolved passes: {})",
            unresolved.join(", ")
        ))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kvasir::nodes::PassId;
    /// Minimal node implementing both node traits; its label is always the field value.
    struct DummyNode(&'static str);
    impl crate::kvasir::node::KvasirNode for DummyNode {
        fn label(&self) -> &'static str {
            self.0
        }
        fn inputs(&self) -> &[ResourceId] {
            &[]
        }
        fn outputs(&self) -> &[ResourceId] {
            &[]
        }
        fn pass_id(&self) -> PassId {
            PassId::Composite
        }
        fn execute(&self, _ctx: &mut ExecutionContext) {}
    }
    impl ErasedKvasirNode for DummyNode {
        fn label(&self) -> &'static str {
            self.0
        }
        fn inputs(&self) -> &[ResourceId] {
            &[]
        }
        fn outputs(&self) -> &[ResourceId] {
            &[]
        }
        fn pass_id(&self) -> PassId {
            PassId::Composite
        }
        fn execute(&self, _ctx: &mut ExecutionContext) {}
    }
    /// Builds a registration with the given id and ordering constraints. The constructor
    /// is a non-capturing closure (coerced to `fn`), so it cannot embed `id`.
    fn make_pass(id: &'static str, after: &'static [&'static str]) -> PassRegistration {
        PassRegistration {
            id,
            label: id,
            inputs: &[],
            outputs: &[],
            after,
            constructor: || Box::new(DummyNode("dummy")) as DynKvasirNode,
        }
    }
    #[test]
    fn test_empty_registry() {
        let reg = PassRegistry::new();
        assert_eq!(reg.len(), 0);
        assert!(reg.is_empty());
    }
    #[test]
    fn test_register_passes() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("a", &[]));
        reg.register(make_pass("b", &[]));
        assert_eq!(reg.len(), 2);
    }
    #[test]
    fn test_lookup_by_id() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("geometry", &[]));
        reg.register(make_pass("ui", &[]));
        assert!(reg.get("geometry").is_some());
        assert!(reg.get("missing").is_none());
    }
    #[test]
    fn test_iter() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("a", &[]));
        reg.register(make_pass("b", &[]));
        let ids: Vec<&str> = reg.iter().map(|p| p.id).collect();
        assert_eq!(ids, vec!["a", "b"]);
    }
    #[test]
    fn test_constructor_builds_node() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("a", &[]));
        let node = (reg.get("a").expect("registered above").constructor)();
        assert_eq!(ErasedKvasirNode::label(&*node), "dummy");
    }
    #[test]
    fn test_topological_order() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("ui", &["geometry"]));
        reg.register(make_pass("composite", &["ui"]));
        reg.register(make_pass("geometry", &[]));
        let order = reg.topo_sort().expect("no cycle");
        let pos = |id: &str| order.iter().position(|&x| x == id).unwrap();
        assert!(pos("geometry") < pos("ui"));
        assert!(pos("ui") < pos("composite"));
    }
    #[test]
    fn test_topological_no_constraints_preserves_insertion_order() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("a", &[]));
        reg.register(make_pass("b", &[]));
        reg.register(make_pass("c", &[]));
        let order = reg.topo_sort().unwrap();
        assert_eq!(order, vec!["a", "b", "c"]);
    }
    #[test]
    fn test_topological_order_is_deterministic_fifo() {
        // Roots are queued in registration order (y, z); x is queued once y is emitted.
        let mut reg = PassRegistry::new();
        reg.register(make_pass("x", &["y"]));
        reg.register(make_pass("y", &[]));
        reg.register(make_pass("z", &[]));
        assert_eq!(reg.topo_sort().unwrap(), vec!["y", "z", "x"]);
    }
    #[test]
    fn test_topological_duplicate_after_entries() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("base", &[]));
        reg.register(make_pass("top", &["base", "base"]));
        assert_eq!(reg.topo_sort().unwrap(), vec!["base", "top"]);
    }
    #[test]
    fn test_topological_unknown_dependency_errors() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("ui", &["ghost"]));
        let err = reg.topo_sort().unwrap_err();
        assert!(err.contains("ghost"));
    }
    #[test]
    fn test_topological_cycle_errors() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("a", &["b"]));
        reg.register(make_pass("b", &["a"]));
        let err = reg.topo_sort().unwrap_err();
        assert!(err.contains("cycle detected"));
    }
    #[test]
    fn test_topological_self_dependency_is_cycle() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("loop", &["loop"]));
        let err = reg.topo_sort().unwrap_err();
        assert!(err.contains("cycle detected"));
    }
    #[test]
    #[should_panic(expected = "duplicate pass id")]
    fn test_duplicate_registration_panics() {
        let mut reg = PassRegistry::new();
        reg.register(make_pass("dup", &[]));
        reg.register(make_pass("dup", &[]));
    }
}