use ezu_graph::{
schema_frag, take_input_ref, BuiltNode, Connection, EvalCtx, EvalError, FactoryCtx,
FactoryError, Node, NodeFactory, PortKind, PortSpec, PortValue,
};
use serde_json::Value;
use xxhash_rust::xxh3::Xxh3;
use crate::nodes::common::read_optional_string;
/// Port kinds accepted by both `switch` inputs. The node forwards whichever value it
/// selects without inspecting it, so every kind listed here can be routed through.
const ACCEPTS_ANY: &[PortKind] = &ANY_PORT_KIND;

/// Backing storage for [`ACCEPTS_ANY`], in the order the kinds are advertised.
const ANY_PORT_KIND: [PortKind; 6] = [
    PortKind::Features,
    PortKind::Raster,
    PortKind::Sprite,
    PortKind::Brush,
    PortKind::Scalar,
    PortKind::ScalarField,
];
/// Which of the two `switch` inputs is forwarded to the output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Select {
    /// Forward input `a` (what the factory picks when `select` is omitted).
    A = 0,
    /// Forward input `b`.
    B = 1,
}

/// Graph node that passes one of its two inputs through unchanged.
struct SwitchNode {
    /// The input chosen when the node was built.
    select: Select,
}
impl SwitchNode {
    /// Position of the selected input in the `inputs()` port list: `0` for `a`, `1` for `b`.
    fn selected_index(&self) -> usize {
        match self.select {
            Select::A => 0,
            Select::B => 1,
        }
    }

    /// Name of the selected input port, matching the names declared in `inputs()`.
    fn selected_port(&self) -> &'static str {
        match self.select {
            Select::A => "a",
            Select::B => "b",
        }
    }
}

impl Node for SwitchNode {
    fn op_name(&self) -> &'static str {
        "switch"
    }

    /// Two required ports, `a` and `b`, each accepting every kind in `ACCEPTS_ANY`.
    fn inputs(&self) -> &[PortSpec] {
        static SPECS: &[PortSpec] = &[
            PortSpec {
                name: "a",
                accepts: ACCEPTS_ANY,
                optional: false,
            },
            PortSpec {
                name: "b",
                accepts: ACCEPTS_ANY,
                optional: false,
            },
        ];
        SPECS
    }

    /// The output kind mirrors the selected input's kind.
    ///
    /// Falls back to `PortKind::Raster` when that kind is unknown (`None`) or when
    /// `input_kinds` is too short to contain it, rather than panicking.
    fn output(&self, input_kinds: &[Option<PortKind>]) -> PortKind {
        input_kinds
            .get(self.selected_index())
            .copied()
            .flatten()
            .unwrap_or(PortKind::Raster)
    }

    /// Returns a clone of the selected input's value.
    ///
    /// # Errors
    ///
    /// `EvalError::MissingInput` naming the selected port (`"a"` or `"b"`) when that
    /// input is absent, including when `inputs` is shorter than expected.
    fn eval(
        &self,
        _ctx: &EvalCtx<'_>,
        inputs: &[Option<PortValue>],
    ) -> Result<PortValue, EvalError> {
        inputs
            .get(self.selected_index())
            .and_then(Option::as_ref)
            .cloned()
            .ok_or_else(|| EvalError::MissingInput(self.selected_port().into()))
    }

    /// Hashes the op name plus one byte for the selection (`0` = `a`, `1` = `b`).
    /// The byte layout is kept stable so existing hashes stay valid.
    fn param_hash(&self, h: &mut Xxh3) {
        h.update(b"switch");
        h.update(&[matches!(self.select, Select::B) as u8]);
    }
}
pub(super) struct SwitchFactory;
impl NodeFactory for SwitchFactory {
fn op_name(&self) -> &'static str {
"switch"
}
fn build(
&self,
fields: &serde_json::Map<String, Value>,
ctx: &FactoryCtx<'_>,
) -> Result<BuiltNode, FactoryError> {
let a = take_input_ref(fields, "a")?;
let b = take_input_ref(fields, "b")?;
let select = match read_optional_string(fields, "select")?.as_deref() {
None | Some("a") => Select::A,
Some("b") => Select::B,
Some(other) => {
let raw = fields.get("select");
if let Some(v) = raw {
if let Some(b_val) = v.as_bool() {
if b_val {
Select::B
} else {
Select::A
}
} else if let Some(n) = v.as_u64() {
if n == 0 {
Select::A
} else {
Select::B
}
} else {
return Err(FactoryError::BadField {
field: "select".into(),
msg: format!("expected `a`/`b` (or bool / 0/1), got `{other}`"),
});
}
} else {
return Err(FactoryError::BadField {
field: "select".into(),
msg: format!("expected `a` or `b`, got `{other}`"),
});
}
}
};
let _ = ctx;
Ok(BuiltNode {
node: Box::new(SwitchNode { select }),
connections: vec![
Connection {
port: "a".into(),
src: a,
},
Connection {
port: "b".into(),
src: b,
},
],
})
}
fn schema(&self) -> Value {
serde_json::json!({
"description": "Pick `a` or `b` based on `select` (default `a`). Both inputs accept any port kind; the output's kind mirrors the selected input. Useful for A/B comparison and param-driven branching.",
"properties": {
"a": schema_frag::node_ref(),
"b": schema_frag::node_ref(),
"select": {
"oneOf": [
{ "type": "string", "enum": ["a", "b"] },
{ "type": "boolean" },
{ "type": "integer", "minimum": 0, "maximum": 1 },
],
"default": "a",
},
},
"required": ["a", "b"],
})
}
}
// Hand `SwitchFactory` to ezu_graph's `submit_node!` macro. NOTE(review): presumably this
// registers the factory so graph descriptions can resolve the `"switch"` op — confirm
// against the macro's definition in ezu_graph.
ezu_graph::submit_node!(SwitchFactory);