vyre_driver/backend/
validation.rs1use super::capability::Backend;
4use std::sync::Arc;
5use vyre_foundation::ir::model::node::Node;
6use vyre_foundation::ir::{OpId, Program, ValidationError};
7
/// Op ids that make up the set returned by `default_supported_ops`.
///
/// The first group are statement-node ids (`vyre.node.*`); the second group are
/// ids without the `vyre.node.` prefix (literals, `var`, operators, load/store).
/// `vyre.node.trap` is not listed here; `default_supported_ops_with_trap` adds it.
const CORE_SUPPORTED_OP_IDS: &[&str] = &[
    // Statement-node ids (compare with the strings produced by `node_op_id`).
    "vyre.node.let",
    "vyre.node.assign",
    "vyre.node.store",
    "vyre.node.if",
    "vyre.node.loop",
    "vyre.node.return",
    "vyre.node.block",
    "vyre.node.barrier",
    "vyre.node.indirect_dispatch",
    "vyre.node.async_load",
    "vyre.node.async_wait",
    "vyre.node.region",
    // Ids without the `vyre.node.` prefix.
    "vyre.lit_u32",
    "vyre.lit_i32",
    "vyre.lit_f32",
    "vyre.lit_bool",
    "vyre.var",
    "vyre.bin_op",
    "vyre.un_op",
    "vyre.load",
    "vyre.store",
];
31
32pub fn validate_program(program: &Program, backend: &dyn Backend) -> Result<(), ValidationError> {
34 for (index, node) in program.entry().iter().enumerate() {
35 validate_node(node, index, backend.id(), backend.supported_ops())?;
36 }
37 Ok(())
38}
39
40pub fn default_supported_ops() -> &'static std::collections::HashSet<OpId> {
42 static OPS: std::sync::OnceLock<std::collections::HashSet<OpId>> = std::sync::OnceLock::new();
43 OPS.get_or_init(|| {
44 let mut ops = std::collections::HashSet::new();
45 let _ = ops.try_reserve(CORE_SUPPORTED_OP_IDS.len());
46 ops.extend(CORE_SUPPORTED_OP_IDS.iter().copied().map(Arc::<str>::from));
47 ops
48 })
49}
50
51pub fn default_supported_ops_with_trap() -> &'static std::collections::HashSet<OpId> {
57 static OPS: std::sync::OnceLock<std::collections::HashSet<OpId>> = std::sync::OnceLock::new();
58 OPS.get_or_init(|| {
59 let base = default_supported_ops();
60 let reserve = base.len().saturating_add(1);
61 let mut ops = std::collections::HashSet::new();
62 let _ = ops.try_reserve(reserve);
63 ops.extend(base.iter().cloned());
64 ops.insert(Arc::<str>::from("vyre.node.trap"));
65 ops
66 })
67}
68
69fn validate_node(
70 node: &Node,
71 index: usize,
72 backend: &'static str,
73 supported: &std::collections::HashSet<OpId>,
74) -> Result<(), ValidationError> {
75 let op = node_op_id(node);
76 if !supported.contains(op) {
77 let op_id = Arc::<str>::from(op);
78 return Err(ValidationError::unsupported_op(backend, &op_id, index));
79 }
80 match node {
81 Node::If {
82 then, otherwise, ..
83 } => {
84 for (offset, nested) in then.iter().enumerate() {
85 validate_node(nested, offset, backend, supported)?;
86 }
87 for (offset, nested) in otherwise.iter().enumerate() {
88 validate_node(nested, offset, backend, supported)?;
89 }
90 }
91 Node::Loop { body, .. } | Node::Block(body) => {
92 for (offset, nested) in body.iter().enumerate() {
93 validate_node(nested, offset, backend, supported)?;
94 }
95 }
96 Node::Region { body, .. } => {
97 for (offset, nested) in body.iter().enumerate() {
98 validate_node(nested, offset, backend, supported)?;
99 }
100 }
101 Node::Let { .. }
104 | Node::Assign { .. }
105 | Node::Store { .. }
106 | Node::Return
107 | Node::Barrier { .. }
108 | Node::IndirectDispatch { .. }
109 | Node::AsyncLoad { .. }
110 | Node::AsyncWait { .. }
111 | Node::Opaque(_) => {}
112 _ => {}
115 }
116 Ok(())
117}
118
119#[must_use]
121pub fn node_op_id(node: &Node) -> &'static str {
122 match node {
123 Node::Let { .. } => "vyre.node.let",
124 Node::Assign { .. } => "vyre.node.assign",
125 Node::Store { .. } => "vyre.node.store",
126 Node::If { .. } => "vyre.node.if",
127 Node::Loop { .. } => "vyre.node.loop",
128 Node::Return => "vyre.node.return",
129 Node::Block(_) => "vyre.node.block",
130 Node::Barrier { .. } => "vyre.node.barrier",
131 Node::IndirectDispatch { .. } => "vyre.node.indirect_dispatch",
132 Node::AsyncLoad { .. } => "vyre.node.async_load",
133 Node::AsyncWait { .. } => "vyre.node.async_wait",
134 Node::Trap { .. } => "vyre.node.trap",
135 Node::Resume { .. } => "vyre.node.resume",
136 Node::AllReduce { .. } => "vyre.node.all_reduce",
137 Node::AllGather { .. } => "vyre.node.all_gather",
138 Node::ReduceScatter { .. } => "vyre.node.reduce_scatter",
139 Node::Broadcast { .. } => "vyre.node.broadcast",
140 Node::Region { .. } => "vyre.node.region",
146 Node::Opaque(extension) => extension.extension_kind(),
147 _ => "vyre.node.unknown",
150 }
151}