use std::fmt;
pub use crate::ir_inner::model::generated::Node;
/// Behaviour that a node payload outside the generated [`Node`] variants must provide.
///
/// `node_op_id` calls [`NodeExtension::extension_kind`] on the value carried by
/// `Node::Opaque`. Presumably that value is a boxed/shared `dyn NodeExtension`;
/// confirm against the generated `Node` definition. Implementors must be
/// `Debug`, `Send`, `Sync`, and `'static`.
pub trait NodeExtension: fmt::Debug + Send + Sync + 'static {
    /// Identifier string for this extension.
    /// `node_op_id` returns this value for `Node::Opaque` nodes.
    fn extension_kind(&self) -> &'static str;

    /// Human-readable identity borrowed from the implementor, for diagnostics.
    fn debug_identity(&self) -> &str;

    /// Returns a 32-byte fingerprint of this extension.
    /// The stability guarantees are implied by the name only. Confirm with implementors.
    fn stable_fingerprint(&self) -> [u8; 32];

    /// Checks the extension's own invariants.
    ///
    /// # Errors
    /// Returns a message describing the first violation found.
    fn validate_extension(&self) -> Result<(), String>;

    /// Exposes `self` as [`std::any::Any`] so callers can downcast to the
    /// concrete extension type.
    fn as_any(&self) -> &dyn std::any::Any;

    /// Extension-specific payload bytes.
    ///
    /// Implementors that carry no payload can rely on this default,
    /// which yields an empty vector.
    fn wire_payload(&self) -> Vec<u8> {
        Vec::default()
    }
}
mod impl_node;
/// Returns the operation identifier string for `node`.
///
/// Every built-in variant maps to a fixed `vyre.node.*` identifier.
/// `Node::Opaque` delegates to the wrapped extension's `extension_kind()`.
/// The match is deliberately exhaustive with no wildcard arm, so adding a
/// `Node` variant forces this table to be updated.
#[must_use]
pub fn node_op_id(node: &Node) -> &'static str {
    match node {
        // Extension nodes report their own identifier.
        Node::Opaque(ext) => ext.extension_kind(),
        // Built-in variants: fixed identifiers.
        Node::Block(_) => "vyre.node.block",
        Node::Region { .. } => "vyre.node.region",
        Node::If { .. } => "vyre.node.if",
        Node::Loop { .. } => "vyre.node.loop",
        Node::Return => "vyre.node.return",
        Node::Let { .. } => "vyre.node.let",
        Node::Assign { .. } => "vyre.node.assign",
        Node::Store { .. } => "vyre.node.store",
        Node::Barrier { .. } => "vyre.node.barrier",
        Node::IndirectDispatch { .. } => "vyre.node.indirect_dispatch",
        Node::AsyncLoad { .. } => "vyre.node.async_load",
        Node::AsyncStore { .. } => "vyre.node.async_store",
        Node::AsyncWait { .. } => "vyre.node.async_wait",
        Node::Trap { .. } => "vyre.node.trap",
        Node::Resume { .. } => "vyre.node.resume",
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};

    /// Serializes `program` with `to_wire`, decodes it with `from_wire`, and
    /// asserts the decoded program equals the original.
    ///
    /// `encode_msg` and `decode_msg` are the panic messages used when
    /// encoding or decoding fails.
    fn assert_wire_round_trip(program: &Program, encode_msg: &str, decode_msg: &str) {
        let bytes = program.to_wire().expect(encode_msg);
        let restored = Program::from_wire(&bytes).expect(decode_msg);
        assert_eq!(restored, *program);
    }

    #[test]
    fn indirect_dispatch_round_trip() {
        let counts = BufferDecl::storage("counts", 0, BufferAccess::ReadOnly, DataType::U32);
        let program = Program::wrapped(
            vec![counts],
            [64, 1, 1],
            vec![Node::indirect_dispatch("counts", 16)],
        );
        assert_wire_round_trip(
            &program,
            "Fix: indirect dispatch must serialize into VIR0",
            "Fix: indirect dispatch must decode from VIR0",
        );
    }

    #[test]
    fn async_load_async_wait_round_trip() {
        let out = BufferDecl::storage("out", 0, BufferAccess::ReadWrite, DataType::U32);
        // A store is placed between the load and the wait on the same tag.
        let body = vec![
            Node::async_load("chunk-0"),
            Node::store("out", Expr::u32(0), Expr::u32(1)),
            Node::async_wait("chunk-0"),
        ];
        let program = Program::wrapped(vec![out], [1, 1, 1], body);
        assert_wire_round_trip(
            &program,
            "Fix: async stream nodes must serialize into VIR0",
            "Fix: async stream nodes must decode from VIR0",
        );
    }
}