1use std::collections::HashMap;
9
10use vyre::ir::{BufferAccess, BufferDecl, Program};
11
12use vyre::Error;
13
14use crate::{
15 eval_node,
16 oob::Buffer,
17 value::Value,
18 workgroup::{self, Invocation, Memory},
19};
20
21pub fn run(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, vyre::Error> {
35 let validation_errors = vyre::ir::validate(program);
36 if !validation_errors.is_empty() {
37 let messages = validation_errors
38 .into_iter()
39 .map(|error| error.message().to_string())
40 .collect::<Vec<_>>()
41 .join("; ");
42 return Err(Error::interp(format!(
43 "program failed IR validation: {messages}. Fix: repair the Program before invoking the reference interpreter."
44 )));
45 }
46
47 let Prepared {
48 storage,
49 output_names,
50 max_elements,
51 } = prepare_storage(program, inputs)?;
52 execute_dispatch(program, storage, output_names, max_elements)
53}
54
55struct Prepared {
56 storage: HashMap<String, Buffer>,
57 output_names: Vec<String>,
58 max_elements: u32,
59}
60
61fn prepare_storage(program: &Program, inputs: &[Value]) -> Result<Prepared, vyre::Error> {
62 let mut storage = HashMap::new();
63 let mut input_index = 0usize;
64 let mut output_names = Vec::new();
65 let mut max_elements = 1u32;
66
67 for decl in program.buffers() {
68 if decl.access() == BufferAccess::Workgroup {
69 continue;
70 }
71 let value = inputs
72 .get(input_index)
73 .ok_or_else(|| Error::interp(format!(
74 "missing input for buffer `{}`. Fix: pass one Value for each non-workgroup buffer in Program::buffers order.",
75 decl.name()
76 )))?;
77 input_index += 1;
78
79 let bytes = value.to_bytes();
80 max_elements = max_elements.max(element_count(decl, bytes.len())?);
81 if decl.access() == BufferAccess::ReadWrite {
82 output_names.push(decl.name().to_string());
83 }
84 storage.insert(
85 decl.name().to_string(),
86 Buffer {
87 bytes,
88 element: decl.element(),
89 },
90 );
91 }
92
93 if input_index != inputs.len() {
94 return Err(Error::interp(
95 "unused input values supplied. Fix: pass exactly one Value per non-workgroup buffer declaration.",
96 ));
97 }
98
99 Ok(Prepared {
100 storage,
101 output_names,
102 max_elements,
103 })
104}
105
106fn execute_dispatch(
107 program: &Program,
108 mut storage: HashMap<String, Buffer>,
109 output_names: Vec<String>,
110 max_elements: u32,
111) -> Result<Vec<Value>, vyre::Error> {
112 validate_workgroup_size(program)?;
113 let invocations_per_workgroup = invocations_per_workgroup(program);
114 let workgroup_count_x = max_elements.div_ceil(invocations_per_workgroup).max(1);
115
116 for wg_x in 0..workgroup_count_x {
117 run_workgroup(program, &mut storage, [wg_x, 0, 0])?;
118 }
119
120 output_names
121 .into_iter()
122 .map(|name| {
123 storage
124 .remove(&name)
125 .map(|buffer| Value::Bytes(buffer.bytes))
126 .ok_or_else(|| Error::interp(format!(
127 "missing output buffer `{name}` after dispatch. Fix: keep buffer declarations unique."
128 )))
129 })
130 .collect()
131}
132
133fn validate_workgroup_size(program: &Program) -> Result<(), vyre::Error> {
134 if program.workgroup_size().contains(&0) {
135 return Err(Error::interp(
136 "workgroup size contains zero. Fix: all dimensions must be >= 1.",
137 ));
138 }
139 Ok(())
140}
141
142fn invocations_per_workgroup(program: &Program) -> u32 {
143 program
144 .workgroup_size()
145 .iter()
146 .copied()
147 .fold(1u32, u32::saturating_mul)
148 .max(1)
149}
150
151fn run_workgroup(
152 program: &Program,
153 storage: &mut HashMap<String, Buffer>,
154 workgroup_id: [u32; 3],
155) -> Result<(), vyre::Error> {
156 let mut memory = Memory {
157 storage: std::mem::take(storage),
158 workgroup: workgroup::workgroup_memory(program)?,
159 };
160 let mut invocations = workgroup::create_invocations(program, workgroup_id)?;
161 run_invocations(program, &mut memory, &mut invocations)?;
162 *storage = memory.storage;
163 Ok(())
164}
165
166fn run_invocations<'a>(
167 program: &'a Program,
168 memory: &mut Memory,
169 invocations: &mut [Invocation<'a>],
170) -> Result<(), vyre::Error> {
171 while invocations.iter().any(|invocation| !invocation.done()) {
172 let made_progress = step_round_robin(program, memory, invocations)?;
173 verify_uniform_control_flow(invocations)?;
174 if release_barrier_if_ready(invocations) {
175 continue;
176 }
177 if !made_progress && live_waiting_count(invocations) > 0 {
178 return Err(Error::interp(
179 "program violates uniform-control-flow rule: not every live invocation reached the same barrier. Fix: move Barrier to uniform control flow.",
180 ));
181 }
182 }
183 Ok(())
184}
185
186fn step_round_robin<'a>(
187 program: &'a Program,
188 memory: &mut Memory,
189 invocations: &mut [Invocation<'a>],
190) -> Result<bool, vyre::Error> {
191 let mut made_progress = false;
192 for invocation in invocations {
193 if invocation.done() || invocation.waiting_at_barrier {
194 continue;
195 }
196 eval_node::step(invocation, memory, program)?;
197 made_progress = true;
198 }
199 Ok(made_progress)
200}
201
202fn release_barrier_if_ready(invocations: &mut [Invocation<'_>]) -> bool {
203 let active = invocations
204 .iter()
205 .filter(|invocation| !invocation.done())
206 .count();
207 let waiting = live_waiting_count(invocations);
208 if active > 0 && active == waiting {
209 for invocation in invocations {
210 invocation.waiting_at_barrier = false;
211 }
212 true
213 } else {
214 false
215 }
216}
217
218fn live_waiting_count(invocations: &[Invocation<'_>]) -> usize {
219 invocations
220 .iter()
221 .filter(|invocation| !invocation.done() && invocation.waiting_at_barrier)
222 .count()
223}
224
225fn verify_uniform_control_flow(invocations: &[Invocation<'_>]) -> Result<(), vyre::Error> {
226 let mut observed: HashMap<usize, bool> = HashMap::new();
235 for invocation in invocations.iter().filter(|invocation| !invocation.done()) {
236 for (id, value) in &invocation.uniform_checks {
237 if let Some(previous) = observed.insert(*id, *value) {
238 if previous != *value {
239 return Err(Error::interp(
240 "program violates uniform-control-flow rule: Barrier appears inside an If whose condition differs across the workgroup. Fix: make the condition uniform or move Barrier outside the branch.",
241 ));
242 }
243 }
244 }
245 }
246 Ok(())
247}
248
249fn element_count(decl: &BufferDecl, byte_len: usize) -> Result<u32, vyre::Error> {
250 let stride = decl.element().min_bytes();
251 if stride == 0 {
252 return u32::try_from(byte_len).map_err(|_| Error::interp(format!(
253 "buffer `{}` has {} bytes and cannot be indexed within u32 address space. Fix: shrink or split the invocation."
254 , decl.name(),
255 byte_len,
256 )));
257 }
258 let elements = byte_len / stride;
259 u32::try_from(elements).map_err(|_| Error::interp(format!(
260 "buffer `{}` has {} bytes for stride {} and overflows u32 elements. Fix: shrink declaration footprint or split work.",
261 decl.name(),
262 byte_len,
263 stride,
264 )))
265}