use std::collections::{HashMap, HashSet};
use crate::error::{Error, Result};
use crate::fieldgraph::{FieldId, FieldProgram, NodeMeta, NodeSpec};
use crate::prelude::FieldGraphSpec;
#[derive(Clone, Debug, Default)]
pub struct CompileOptions {
pub force_bake: HashSet<FieldId>,
}
pub struct FieldGraphCompiler;
impl FieldGraphCompiler {
pub fn compile(spec: &FieldGraphSpec, opts: &CompileOptions) -> Result<FieldProgram> {
let mut nodes: HashMap<FieldId, NodeMeta> = HashMap::new();
for (id, node_spec) in &spec.nodes {
for input in node_spec.inputs() {
if !spec.nodes.contains_key(input) {
return Err(Error::Compile(format!(
"Node '{id}' references unknown input '{input}'"
)));
}
}
validate_node_inputs(id, node_spec)?;
let force_bake = opts.force_bake.contains(id);
nodes.insert(
id.clone(),
NodeMeta {
id: id.clone(),
spec: node_spec.clone(),
force_bake,
semantics: spec.semantics.get(id).cloned(),
},
);
}
let topo = topo_sort(&nodes)?;
Ok(FieldProgram { nodes, topo })
}
}
fn validate_node_inputs(id: &str, node_spec: &NodeSpec) -> Result<()> {
let inputs = node_spec.inputs();
let ensure_at_least_one = |variant: &str| {
if inputs.is_empty() {
Err(Error::Compile(format!(
"Node '{id}' ({variant}) requires at least one input"
)))
} else {
Ok(())
}
};
let ensure_exactly_one = |variant: &str| {
if inputs.len() != 1 {
Err(Error::Compile(format!(
"Node '{}' ({}) requires exactly one input but found {}",
id,
variant,
inputs.len()
)))
} else {
Ok(())
}
};
match node_spec {
NodeSpec::Constant { .. } | NodeSpec::Texture { .. } => Ok(()),
NodeSpec::Add { .. } => ensure_at_least_one("Add"),
NodeSpec::Sub { .. } => ensure_at_least_one("Sub"),
NodeSpec::Mul { .. } => ensure_at_least_one("Mul"),
NodeSpec::Min { .. } => ensure_at_least_one("Min"),
NodeSpec::Max { .. } => ensure_at_least_one("Max"),
NodeSpec::Invert { .. } => ensure_exactly_one("Invert"),
NodeSpec::Scale { .. } => ensure_exactly_one("Scale"),
NodeSpec::Clamp { .. } => ensure_exactly_one("Clamp"),
NodeSpec::SmoothStep { .. } => ensure_exactly_one("SmoothStep"),
NodeSpec::Pow { .. } => ensure_exactly_one("Pow"),
NodeSpec::EdtNormalize { .. } => ensure_exactly_one("EdtNormalize"),
}
}
fn topo_sort(nodes: &HashMap<FieldId, NodeMeta>) -> Result<Vec<FieldId>> {
let mut indeg: HashMap<&str, usize> = HashMap::new();
let mut dependents: HashMap<&str, HashMap<&str, usize>> = HashMap::new();
for (id, meta) in nodes {
let id_str = id.as_str();
let inputs = meta.spec.inputs();
indeg.insert(id_str, inputs.len());
for input in inputs {
dependents
.entry(input.as_str())
.or_default()
.entry(id_str)
.and_modify(|count| *count += 1)
.or_insert(1);
}
}
let mut q: Vec<&str> = indeg
.iter()
.filter_map(|(k, &v)| if v == 0 { Some(*k) } else { None })
.collect();
let mut out: Vec<FieldId> = Vec::new();
while let Some(n) = q.pop() {
out.push(n.to_string());
if let Some(children) = dependents.get(n) {
for (child, count) in children {
if let Some(e) = indeg.get_mut(child) {
*e = e.saturating_sub(*count);
if *e == 0 {
q.push(child);
}
}
}
}
}
if out.len() != nodes.len() {
return Err(Error::Compile("Cycle detected or missing nodes".into()));
}
Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::fieldgraph::node::{PowParams, ScaleParams};
use crate::prelude::{FieldSemantics, NodeSpec};
#[test]
fn compile_orders_nodes_topologically() {
let mut spec = FieldGraphSpec::default();
spec.add("a", NodeSpec::constant(1.0));
spec.add("b", NodeSpec::add(vec!["a".into()]));
spec.add_with_semantics(
"prob",
NodeSpec::mul(vec!["a".into(), "b".into()]),
FieldSemantics::Probability,
);
let opts = CompileOptions::default();
let program = FieldGraphCompiler::compile(&spec, &opts).expect("compile succeeds");
assert_eq!(program.nodes.len(), 3);
let pos_a = program.topo.iter().position(|id| id == "a").unwrap();
let pos_b = program.topo.iter().position(|id| id == "b").unwrap();
let pos_prob = program.topo.iter().position(|id| id == "prob").unwrap();
assert!(pos_a < pos_b && pos_b < pos_prob);
}
#[test]
fn compile_detects_unknown_inputs() {
let mut spec = FieldGraphSpec::default();
spec.add("bad", NodeSpec::add(vec!["missing".into()]));
let err = FieldGraphCompiler::compile(&spec, &CompileOptions::default())
.expect_err("expected compile failure");
matches!(err, Error::Compile(_))
.then_some(())
.expect("compile error");
}
#[test]
fn compile_rejects_nodes_with_missing_inputs() {
let mut spec = FieldGraphSpec::default();
spec.add("a", NodeSpec::constant(1.0));
spec.add("bad_min", NodeSpec::min(Vec::new()));
let err = FieldGraphCompiler::compile(&spec, &CompileOptions::default())
.expect_err("missing inputs should fail");
matches!(err, Error::Compile(_))
.then_some(())
.expect("compile error");
}
#[test]
fn compile_rejects_nodes_with_extra_inputs() {
let mut spec = FieldGraphSpec::default();
spec.add("a", NodeSpec::constant(1.0));
spec.add(
"bad_scale",
NodeSpec::Scale {
inputs: vec!["a".into(), "a".into()],
params: ScaleParams { factor: 2.0 },
},
);
spec.add(
"bad_pow",
NodeSpec::Pow {
inputs: vec![],
params: PowParams { exp: 2.0 },
},
);
let err = FieldGraphCompiler::compile(&spec, &CompileOptions::default())
.expect_err("extra or missing inputs should fail");
matches!(err, Error::Compile(_))
.then_some(())
.expect("compile error");
}
#[test]
fn compile_detects_cycles() {
let mut spec = FieldGraphSpec::default();
spec.add("a", NodeSpec::add(vec!["b".into()]));
spec.add("b", NodeSpec::add(vec!["a".into()]));
let err = FieldGraphCompiler::compile(&spec, &CompileOptions::default())
.expect_err("cycle should fail");
matches!(err, Error::Compile(_))
.then_some(())
.expect("compile error");
}
#[test]
fn compile_marks_force_bake_nodes() {
let mut spec = FieldGraphSpec::default();
spec.add("base", NodeSpec::constant(0.0));
spec.add("baked", NodeSpec::scale("base".into(), 2.0));
let mut opts = CompileOptions::default();
opts.force_bake.insert("baked".into());
let program = FieldGraphCompiler::compile(&spec, &opts).expect("compile succeeds");
assert!(program.nodes.get("baked").expect("node exists").force_bake);
}
#[test]
fn compile_handles_duplicate_inputs() {
let mut spec = FieldGraphSpec::default();
spec.add("a", NodeSpec::constant(1.0));
spec.add("square", NodeSpec::mul(vec!["a".into(), "a".into()]));
let program = FieldGraphCompiler::compile(&spec, &CompileOptions::default())
.expect("compile succeeds");
assert_eq!(program.topo.len(), 2);
assert!(program.topo.iter().any(|f| f == "square"));
}
}