1use super::{ComputeGraph, TensorNode, TensorOp};
21
/// Summary of one graph-execution pass.
///
/// NOTE(review): not constructed by `execute_graph` in this file (which
/// returns a bare launch count) — presumably built by callers that also
/// time the pass; confirm at the call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct GraphExecResult {
    /// Number of kernel launches performed.
    pub n_launches: usize,
    /// Elapsed wall-clock time, presumably in microseconds per the `_us`
    /// suffix, when timing was captured; `None` otherwise.
    pub elapsed_us: Option<u64>,
}
30
/// Backend-agnostic interface for launching the kernels a compute graph
/// requires. `execute_graph` walks the graph and issues exactly one of these
/// calls per op node; implementors translate each call into an actual launch.
/// Pointer arguments are raw `u64` addresses — presumably device memory,
/// opaque at this layer; confirm against the backend implementations.
pub trait KernelDispatch {
    /// Launch a matrix multiplication writing `node`'s output buffer.
    ///
    /// `input_ptr`/`output_ptr` are the activation source and destination;
    /// `m`, `n`, `k` are the GEMM dimensions. The weight operand is
    /// presumably carried inside `node.params` — confirm.
    fn dispatch_mul_mat(
        &mut self,
        node: &TensorNode,
        input_ptr: u64,
        output_ptr: u64,
        m: u32,
        n: u32,
        k: u32,
    ) -> Result<(), crate::GpuError>;

    /// Launch RMS normalization over `m` rows of width `hidden_dim`,
    /// with `epsilon` as the numerical-stability term
    /// (`execute_graph` passes `node.params.scalar` here).
    fn dispatch_rms_norm(
        &mut self,
        node: &TensorNode,
        input_ptr: u64,
        output_ptr: u64,
        hidden_dim: u32,
        m: u32,
        epsilon: f32,
    ) -> Result<(), crate::GpuError>;

    /// Launch element-wise addition of `n_elements` items from `a_ptr` and
    /// `b_ptr` into `output_ptr`.
    fn dispatch_add(
        &mut self,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), crate::GpuError>;

    /// Launch rotary position embedding in place on the buffer at `qk_ptr`.
    ///
    /// NOTE(review): `execute_graph` currently passes `positions` as an
    /// empty slice — confirm whether implementors obtain positions elsewhere
    /// or this is a placeholder.
    fn dispatch_rope(
        &mut self,
        node: &TensorNode,
        qk_ptr: u64,
        positions: &[u32],
        head_dim: u32,
        num_heads: u32,
    ) -> Result<(), crate::GpuError>;

    /// Launch fused attention: reads the q/k/v buffers and writes
    /// `output_ptr`. `m` is the row count; `layer_idx` selects per-layer
    /// state (presumably a KV cache — confirm).
    fn dispatch_attention(
        &mut self,
        node: &TensorNode,
        q_ptr: u64,
        k_ptr: u64,
        v_ptr: u64,
        output_ptr: u64,
        m: u32,
        layer_idx: usize,
    ) -> Result<(), crate::GpuError>;

    /// Copy `size_bytes` bytes from `src_ptr` to `dst_ptr`.
    fn dispatch_copy(
        &mut self,
        src_ptr: u64,
        dst_ptr: u64,
        size_bytes: usize,
    ) -> Result<(), crate::GpuError>;

    /// Launch element-wise multiplication of `n_elements` items from
    /// `a_ptr` and `b_ptr` into `output_ptr`.
    fn dispatch_mul(
        &mut self,
        a_ptr: u64,
        b_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), crate::GpuError>;

    /// Launch the SiLU activation over `n_elements` items.
    fn dispatch_silu(
        &mut self,
        input_ptr: u64,
        output_ptr: u64,
        n_elements: usize,
    ) -> Result<(), crate::GpuError>;
}
123
124pub fn execute_graph<D: KernelDispatch>(
131 graph: &ComputeGraph,
132 dispatcher: &mut D,
133) -> Result<usize, crate::GpuError> {
134 let mut n_launches = 0;
135
136 for node in &graph.nodes {
137 match node.op {
138 TensorOp::None => {
139 }
141 TensorOp::MulMat => {
142 let input_idx = node.inputs.first().copied().unwrap_or(0);
143 let input_ptr = graph.nodes[input_idx].data_ptr;
144 dispatcher.dispatch_mul_mat(
145 node,
146 input_ptr,
147 node.data_ptr,
148 node.shape[2], node.shape[0], node.shape[1], )?;
152 n_launches += 1;
153 }
154 TensorOp::RmsNorm => {
155 let input_idx = node.inputs.first().copied().unwrap_or(0);
156 let input_ptr = graph.nodes[input_idx].data_ptr;
157 dispatcher.dispatch_rms_norm(
158 node,
159 input_ptr,
160 node.data_ptr,
161 node.shape[0], node.shape[2], node.params.scalar, )?;
165 n_launches += 1;
166 }
167 TensorOp::Add => {
168 let a_idx = node.inputs.first().copied().unwrap_or(0);
169 let b_idx = node.inputs.get(1).copied().unwrap_or(0);
170 let a_ptr = graph.nodes[a_idx].data_ptr;
171 let b_ptr = graph.nodes[b_idx].data_ptr;
172 let n_elements = (node.shape[0] * node.shape[2]) as usize;
173 dispatcher.dispatch_add(a_ptr, b_ptr, node.data_ptr, n_elements)?;
174 n_launches += 1;
175 }
176 TensorOp::Rope => {
177 let input_idx = node.inputs.first().copied().unwrap_or(0);
178 let input_ptr = graph.nodes[input_idx].data_ptr;
179 dispatcher.dispatch_rope(
181 node,
182 input_ptr,
183 &[], node.shape[0], node.shape[1], )?;
187 n_launches += 1;
188 }
189 TensorOp::SoftMax => {
190 let q_idx = node.inputs.first().copied().unwrap_or(0);
193 let k_idx = node.inputs.get(1).copied().unwrap_or(0);
194 let v_idx = node.inputs.get(2).copied().unwrap_or(0);
195 dispatcher.dispatch_attention(
196 node,
197 graph.nodes[q_idx].data_ptr,
198 graph.nodes[k_idx].data_ptr,
199 graph.nodes[v_idx].data_ptr,
200 node.data_ptr,
201 node.shape[2], node.params.int_param as usize, )?;
204 n_launches += 1;
205 }
206 TensorOp::Copy => {
207 let src_idx = node.inputs.first().copied().unwrap_or(0);
208 let src_ptr = graph.nodes[src_idx].data_ptr;
209 let size = (node.shape[0] * node.shape[1] * 4) as usize; dispatcher.dispatch_copy(src_ptr, node.data_ptr, size)?;
211 n_launches += 1;
212 }
213 TensorOp::Mul => {
214 let a_idx = node.inputs.first().copied().unwrap_or(0);
215 let b_idx = node.inputs.get(1).copied().unwrap_or(0);
216 let n_elements = (node.shape[0] * node.shape[2]) as usize;
217 dispatcher.dispatch_mul(
218 graph.nodes[a_idx].data_ptr,
219 graph.nodes[b_idx].data_ptr,
220 node.data_ptr,
221 n_elements,
222 )?;
223 n_launches += 1;
224 }
225 TensorOp::Silu => {
226 let input_idx = node.inputs.first().copied().unwrap_or(0);
227 let n_elements = (node.shape[0] * node.shape[2]) as usize;
228 dispatcher.dispatch_silu(
229 graph.nodes[input_idx].data_ptr,
230 node.data_ptr,
231 n_elements,
232 )?;
233 n_launches += 1;
234 }
235 }
236 }
237
238 Ok(n_launches)
239}
240
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that counts every kernel launch it is asked to perform.
    struct CountingDispatcher {
        launches: usize,
    }

    impl CountingDispatcher {
        /// Record one launch and report success.
        fn bump(&mut self) -> Result<(), crate::GpuError> {
            self.launches += 1;
            Ok(())
        }
    }

    impl KernelDispatch for CountingDispatcher {
        fn dispatch_mul_mat(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u32,
            _: u32,
            _: u32,
        ) -> Result<(), crate::GpuError> {
            self.bump()
        }

        fn dispatch_rms_norm(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u32,
            _: u32,
            _: f32,
        ) -> Result<(), crate::GpuError> {
            self.bump()
        }

        fn dispatch_add(
            &mut self,
            _: u64,
            _: u64,
            _: u64,
            _: usize,
        ) -> Result<(), crate::GpuError> {
            self.bump()
        }

        fn dispatch_rope(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: &[u32],
            _: u32,
            _: u32,
        ) -> Result<(), crate::GpuError> {
            self.bump()
        }

        fn dispatch_attention(
            &mut self,
            _: &TensorNode,
            _: u64,
            _: u64,
            _: u64,
            _: u64,
            _: u32,
            _: usize,
        ) -> Result<(), crate::GpuError> {
            self.bump()
        }

        fn dispatch_copy(&mut self, _: u64, _: u64, _: usize) -> Result<(), crate::GpuError> {
            self.bump()
        }

        fn dispatch_mul(
            &mut self,
            _: u64,
            _: u64,
            _: u64,
            _: usize,
        ) -> Result<(), crate::GpuError> {
            self.bump()
        }

        fn dispatch_silu(&mut self, _: u64, _: u64, _: usize) -> Result<(), crate::GpuError> {
            self.bump()
        }
    }

    #[test]
    fn test_execute_empty_graph() {
        let graph = ComputeGraph::new();
        let mut counter = CountingDispatcher { launches: 0 };
        assert_eq!(execute_graph(&graph, &mut counter).unwrap(), 0);
        assert_eq!(counter.launches, 0);
    }

    #[test]
    fn test_execute_single_layer_graph() {
        use super::super::OpParams;

        // One transformer-style layer: norm -> q/k/v projections ->
        // attention -> residual add.
        let mut graph = ComputeGraph::new();
        let hidden = [1536, 1, 4, 0];

        let input = graph.add_leaf(0x1000, hidden);
        let normed = graph.add_op(
            TensorOp::RmsNorm,
            0x2000,
            hidden,
            vec![input],
            OpParams { gamma_ptr: 0x3000, scalar: 1e-6, ..Default::default() },
        );
        let q = graph.add_op(
            TensorOp::MulMat,
            0x4000,
            [1536, 1536, 4, 0],
            vec![normed],
            OpParams { weight_ptr: 0x5000, ..Default::default() },
        );
        let k = graph.add_op(
            TensorOp::MulMat,
            0x6000,
            [256, 1536, 4, 0],
            vec![normed],
            OpParams { weight_ptr: 0x7000, ..Default::default() },
        );
        let v = graph.add_op(
            TensorOp::MulMat,
            0x8000,
            [256, 1536, 4, 0],
            vec![normed],
            OpParams { weight_ptr: 0x9000, ..Default::default() },
        );
        let attn = graph.add_op(
            TensorOp::SoftMax,
            0xA000,
            hidden,
            vec![q, k, v],
            OpParams { int_param: 0, ..Default::default() },
        );
        let _residual = graph.add_op(
            TensorOp::Add,
            0xB000,
            hidden,
            vec![input, attn],
            OpParams::default(),
        );

        let mut counter = CountingDispatcher { launches: 0 };
        let launched = execute_graph(&graph, &mut counter).unwrap();

        // Six op nodes -> six launches; the leaf node launches nothing.
        assert_eq!(launched, 6);
        assert_eq!(counter.launches, 6);
        assert_eq!(graph.n_ops(), 6);
    }
}