use super::capability::Backend;
use std::sync::Arc;
use vyre_foundation::ir::model::node::Node;
use vyre_foundation::ir::{OpId, Program, ValidationError};
pub fn validate_program(program: &Program, backend: &dyn Backend) -> Result<(), ValidationError> {
for (index, node) in program.entry().iter().enumerate() {
validate_node(node, index, backend.id(), backend.supported_ops())?;
}
Ok(())
}
pub fn default_supported_ops() -> &'static std::collections::HashSet<OpId> {
static OPS: std::sync::OnceLock<std::collections::HashSet<OpId>> = std::sync::OnceLock::new();
OPS.get_or_init(|| {
[
"vyre.node.let",
"vyre.node.assign",
"vyre.node.store",
"vyre.node.if",
"vyre.node.loop",
"vyre.node.return",
"vyre.node.block",
"vyre.node.barrier",
"vyre.node.indirect_dispatch",
"vyre.node.async_load",
"vyre.node.async_wait",
"vyre.node.region",
"vyre.lit_u32",
"vyre.lit_i32",
"vyre.lit_f32",
"vyre.lit_bool",
"vyre.var",
"vyre.bin_op",
"vyre.un_op",
"vyre.load",
"vyre.store",
]
.into_iter()
.map(Arc::<str>::from)
.collect()
})
}
pub fn default_supported_ops_with_trap() -> &'static std::collections::HashSet<OpId> {
static OPS: std::sync::OnceLock<std::collections::HashSet<OpId>> = std::sync::OnceLock::new();
OPS.get_or_init(|| {
let mut ops = default_supported_ops().clone();
ops.insert(Arc::<str>::from("vyre.node.trap"));
ops
})
}
fn validate_node(
node: &Node,
index: usize,
backend: &'static str,
supported: &std::collections::HashSet<OpId>,
) -> Result<(), ValidationError> {
let op = node_op_id(node);
if !supported.contains(op) {
return Err(ValidationError::unsupported_op(
backend,
Arc::from(op),
index,
));
}
match node {
Node::If {
then, otherwise, ..
} => {
for (offset, nested) in then.iter().enumerate() {
validate_node(nested, offset, backend, supported)?;
}
for (offset, nested) in otherwise.iter().enumerate() {
validate_node(nested, offset, backend, supported)?;
}
}
Node::Loop { body, .. } | Node::Block(body) => {
for (offset, nested) in body.iter().enumerate() {
validate_node(nested, offset, backend, supported)?;
}
}
Node::Let { .. }
| Node::Assign { .. }
| Node::Store { .. }
| Node::Return
| Node::Barrier { .. }
| Node::IndirectDispatch { .. }
| Node::AsyncLoad { .. }
| Node::AsyncWait { .. }
| Node::Opaque(_) => {}
_ => {}
}
Ok(())
}
#[must_use]
pub fn node_op_id(node: &Node) -> &'static str {
match node {
Node::Let { .. } => "vyre.node.let",
Node::Assign { .. } => "vyre.node.assign",
Node::Store { .. } => "vyre.node.store",
Node::If { .. } => "vyre.node.if",
Node::Loop { .. } => "vyre.node.loop",
Node::Return => "vyre.node.return",
Node::Block(_) => "vyre.node.block",
Node::Barrier { .. } => "vyre.node.barrier",
Node::IndirectDispatch { .. } => "vyre.node.indirect_dispatch",
Node::AsyncLoad { .. } => "vyre.node.async_load",
Node::AsyncWait { .. } => "vyre.node.async_wait",
Node::Trap { .. } => "vyre.node.trap",
Node::Resume { .. } => "vyre.node.resume",
Node::Region { .. } => "vyre.node.region",
Node::Opaque(extension) => extension.extension_kind(),
_ => "vyre.node.unknown",
}
}