Skip to main content

wave_runtime/
launcher.rs

1// Copyright 2026 Ojima Abraham
2// SPDX-License-Identifier: Apache-2.0
3
4//! Kernel launch and synchronization for the WAVE runtime.
5//!
6//! Dispatches compiled kernels to the appropriate GPU vendor or the WAVE
7//! emulator. For vendor backends (Metal, CUDA, HIP, SYCL), generates a host
8//! program, compiles it via subprocess, and runs it. The emulator path calls
9//! `wave_emu` directly as a library.
10
11#![allow(clippy::format_push_string)]
12
13use crate::device::GpuVendor;
14use crate::error::RuntimeError;
15use crate::memory::DeviceBuffer;
16use std::fs;
17use std::path::Path;
18use std::process::Command;
19
20/// Launch a compiled kernel on the specified vendor.
21///
22/// # Errors
23///
24/// Returns `RuntimeError::Launch` if the kernel cannot be launched, compiled,
25/// or executed on the target device.
26#[allow(clippy::too_many_arguments)]
27pub fn launch_kernel(
28    vendor_code: &str,
29    wbin: &[u8],
30    vendor: GpuVendor,
31    buffers: &mut [&mut DeviceBuffer],
32    scalars: &[u32],
33    grid: [u32; 3],
34    workgroup: [u32; 3],
35) -> Result<(), RuntimeError> {
36    match vendor {
37        GpuVendor::Apple => launch_metal(vendor_code, buffers, scalars, grid, workgroup),
38        GpuVendor::Nvidia => launch_cuda(vendor_code, buffers, scalars, grid, workgroup),
39        GpuVendor::Amd => launch_hip(vendor_code, buffers, scalars, grid, workgroup),
40        GpuVendor::Intel => launch_sycl(vendor_code, buffers, scalars, grid, workgroup),
41        GpuVendor::Emulator => launch_emulator(wbin, buffers, scalars, grid, workgroup),
42    }
43}
44
45fn launch_emulator(
46    wbin: &[u8],
47    buffers: &mut [&mut DeviceBuffer],
48    scalars: &[u32],
49    grid: [u32; 3],
50    workgroup: [u32; 3],
51) -> Result<(), RuntimeError> {
52    let total_buffer_bytes: usize = buffers.iter().map(|b| b.size_bytes()).sum();
53    let device_mem_size = (total_buffer_bytes + 4096).max(1024 * 1024);
54
55    let mut config = wave_emu::EmulatorConfig {
56        grid_dim: grid,
57        workgroup_dim: workgroup,
58        device_memory_size: device_mem_size,
59        ..wave_emu::EmulatorConfig::default()
60    };
61
62    let mut initial_regs: Vec<(u8, u32)> = Vec::new();
63    let mut offset: u32 = 0;
64    for (i, buf) in buffers.iter().enumerate() {
65        #[allow(clippy::cast_possible_truncation)]
66        let reg_idx = i as u8;
67        initial_regs.push((reg_idx, offset));
68        #[allow(clippy::cast_possible_truncation)]
69        let size = buf.size_bytes() as u32;
70        offset += size;
71    }
72    for (i, &scalar) in scalars.iter().enumerate() {
73        #[allow(clippy::cast_possible_truncation)]
74        let reg_idx = (buffers.len() + i) as u8;
75        initial_regs.push((reg_idx, scalar));
76    }
77    config.initial_registers = initial_regs;
78
79    let mut emu = wave_emu::Emulator::new(config);
80    emu.load_binary(wbin)?;
81
82    let mut mem_offset: u64 = 0;
83    for buf in buffers.iter() {
84        emu.load_device_memory(mem_offset, &buf.data)?;
85        #[allow(clippy::cast_possible_truncation)]
86        let size = buf.size_bytes() as u64;
87        mem_offset += size;
88    }
89
90    emu.run()?;
91
92    let mut read_offset: u64 = 0;
93    for buf in buffers.iter_mut() {
94        let size = buf.size_bytes();
95        let result = emu.read_device_memory(read_offset, size)?;
96        buf.data = result;
97        read_offset += size as u64;
98    }
99
100    Ok(())
101}
102
103fn write_buffer_files(
104    dir: &Path,
105    buffers: &[&mut DeviceBuffer],
106) -> Result<Vec<String>, RuntimeError> {
107    let mut paths = Vec::new();
108    for (i, buf) in buffers.iter().enumerate() {
109        let path = dir.join(format!("buf_{i}.bin"));
110        fs::write(&path, &buf.data)?;
111        paths.push(
112            path.to_str()
113                .ok_or_else(|| RuntimeError::Io("invalid path".into()))?
114                .to_string(),
115        );
116    }
117    Ok(paths)
118}
119
120fn read_buffer_files(dir: &Path, buffers: &mut [&mut DeviceBuffer]) -> Result<(), RuntimeError> {
121    for (i, buf) in buffers.iter_mut().enumerate() {
122        let path = dir.join(format!("buf_{i}.bin"));
123        buf.data = fs::read(&path)?;
124    }
125    Ok(())
126}
127
128fn launch_metal(
129    vendor_code: &str,
130    buffers: &mut [&mut DeviceBuffer],
131    scalars: &[u32],
132    grid: [u32; 3],
133    workgroup: [u32; 3],
134) -> Result<(), RuntimeError> {
135    let dir = tempfile::tempdir().map_err(|e| RuntimeError::Io(e.to_string()))?;
136    let metal_path = dir.path().join("kernel.metal");
137    fs::write(&metal_path, vendor_code)?;
138
139    let buf_paths = write_buffer_files(dir.path(), buffers)?;
140
141    let host_src = generate_metal_host(&buf_paths, scalars, grid, workgroup);
142    let host_path = dir.path().join("host.swift");
143    fs::write(&host_path, &host_src)?;
144
145    let lib_path = dir.path().join("kernel.metallib");
146    let status = Command::new("xcrun")
147        .args(["-sdk", "macosx", "metal", "-o"])
148        .arg(&lib_path)
149        .arg(&metal_path)
150        .status()?;
151    if !status.success() {
152        return Err(RuntimeError::Launch(
153            "Metal shader compilation failed".into(),
154        ));
155    }
156
157    let exe_path = dir.path().join("host");
158    let status = Command::new("swiftc")
159        .arg("-o")
160        .arg(&exe_path)
161        .arg(&host_path)
162        .arg("-framework")
163        .arg("Metal")
164        .arg("-framework")
165        .arg("Foundation")
166        .status()?;
167    if !status.success() {
168        return Err(RuntimeError::Launch("Swift host compilation failed".into()));
169    }
170
171    let status = Command::new(&exe_path).arg(&lib_path).status()?;
172    if !status.success() {
173        return Err(RuntimeError::Launch("Metal kernel execution failed".into()));
174    }
175
176    read_buffer_files(dir.path(), buffers)?;
177    Ok(())
178}
179
180fn launch_cuda(
181    vendor_code: &str,
182    buffers: &mut [&mut DeviceBuffer],
183    scalars: &[u32],
184    grid: [u32; 3],
185    workgroup: [u32; 3],
186) -> Result<(), RuntimeError> {
187    let dir = tempfile::tempdir().map_err(|e| RuntimeError::Io(e.to_string()))?;
188
189    let host_src = generate_cuda_host(vendor_code, &[], scalars, grid, workgroup, buffers);
190    let cu_path = dir.path().join("kernel.cu");
191    fs::write(&cu_path, &host_src)?;
192
193    let buf_paths = write_buffer_files(dir.path(), buffers)?;
194
195    let exe_path = dir.path().join("kernel");
196    let status = Command::new("nvcc")
197        .arg("-o")
198        .arg(&exe_path)
199        .arg(&cu_path)
200        .status()?;
201    if !status.success() {
202        return Err(RuntimeError::Launch("CUDA compilation failed".into()));
203    }
204
205    let status = Command::new(&exe_path).args(&buf_paths).status()?;
206    if !status.success() {
207        return Err(RuntimeError::Launch("CUDA kernel execution failed".into()));
208    }
209
210    read_buffer_files(dir.path(), buffers)?;
211    Ok(())
212}
213
214fn launch_hip(
215    vendor_code: &str,
216    buffers: &mut [&mut DeviceBuffer],
217    scalars: &[u32],
218    grid: [u32; 3],
219    workgroup: [u32; 3],
220) -> Result<(), RuntimeError> {
221    let dir = tempfile::tempdir().map_err(|e| RuntimeError::Io(e.to_string()))?;
222
223    let host_src = generate_hip_host(vendor_code, &[], scalars, grid, workgroup, buffers);
224    let hip_path = dir.path().join("kernel.hip");
225    fs::write(&hip_path, &host_src)?;
226
227    let buf_paths = write_buffer_files(dir.path(), buffers)?;
228
229    let exe_path = dir.path().join("kernel");
230    let status = Command::new("hipcc")
231        .arg("-o")
232        .arg(&exe_path)
233        .arg(&hip_path)
234        .status()?;
235    if !status.success() {
236        return Err(RuntimeError::Launch("HIP compilation failed".into()));
237    }
238
239    let status = Command::new(&exe_path).args(&buf_paths).status()?;
240    if !status.success() {
241        return Err(RuntimeError::Launch("HIP kernel execution failed".into()));
242    }
243
244    read_buffer_files(dir.path(), buffers)?;
245    Ok(())
246}
247
248fn launch_sycl(
249    vendor_code: &str,
250    buffers: &mut [&mut DeviceBuffer],
251    scalars: &[u32],
252    grid: [u32; 3],
253    workgroup: [u32; 3],
254) -> Result<(), RuntimeError> {
255    let dir = tempfile::tempdir().map_err(|e| RuntimeError::Io(e.to_string()))?;
256
257    let host_src = generate_sycl_host(vendor_code, &[], scalars, grid, workgroup, buffers);
258    let cpp_path = dir.path().join("kernel.cpp");
259    fs::write(&cpp_path, &host_src)?;
260
261    let buf_paths = write_buffer_files(dir.path(), buffers)?;
262
263    let exe_path = dir.path().join("kernel");
264    let status = Command::new("icpx")
265        .arg("-fsycl")
266        .arg("-o")
267        .arg(&exe_path)
268        .arg(&cpp_path)
269        .status()?;
270    if !status.success() {
271        return Err(RuntimeError::Launch("SYCL compilation failed".into()));
272    }
273
274    let status = Command::new(&exe_path).args(&buf_paths).status()?;
275    if !status.success() {
276        return Err(RuntimeError::Launch("SYCL kernel execution failed".into()));
277    }
278
279    read_buffer_files(dir.path(), buffers)?;
280    Ok(())
281}
282
/// Generate the Swift host program for a Metal launch.
///
/// The emitted program loads the metallib given as `argv[1]`, takes its
/// first function as the kernel entry point, binds each `buf_paths[i]` file
/// as buffer index `i` and each scalar as `setBytes` at the indices after
/// the buffers, dispatches `grid` threadgroups of `workgroup` threads, and
/// writes the buffer contents back to the same files.
fn generate_metal_host(
    buf_paths: &[String],
    scalars: &[u32],
    grid: [u32; 3],
    workgroup: [u32; 3],
) -> String {
    // Preamble: device, pipeline from the metallib's first function, encoder.
    let mut host = String::from(
        "import Metal\nimport Foundation\n\n\
         let device = MTLCreateSystemDefaultDevice()!\n\
         let lib = try! device.makeLibrary(filepath: CommandLine.arguments[1])\n\
         let function = lib.functionNames.first.flatMap { lib.makeFunction(name: $0) }!\n\
         let pipeline = try! device.makeComputePipelineState(function: function)\n\
         let queue = device.makeCommandQueue()!\n\
         let cmd = queue.makeCommandBuffer()!\n\
         let enc = cmd.makeComputeCommandEncoder()!\n\
         enc.setComputePipelineState(pipeline)\n\n",
    );

    // One shared-storage MTLBuffer per input file, bound at index i.
    for (i, path) in buf_paths.iter().enumerate() {
        host.push_str(&format!(
            "let data{i} = try! Data(contentsOf: URL(fileURLWithPath: \"{path}\"))\n\
             let buf{i} = device.makeBuffer(bytes: (data{i} as NSData).bytes, length: data{i}.count, options: .storageModeShared)!\n\
             enc.setBuffer(buf{i}, offset: 0, index: {i})\n"
        ));
    }

    // Scalars are passed inline via setBytes at the indices after the buffers.
    for (i, &value) in scalars.iter().enumerate() {
        let idx = buf_paths.len() + i;
        host.push_str(&format!(
            "var scalar{i}: UInt32 = {value}\nenc.setBytes(&scalar{i}, length: 4, index: {idx})\n"
        ));
    }

    host.push_str(&format!(
        "\nenc.dispatchThreadgroups(MTLSize(width: {gw}, height: {gh}, depth: {gd}), threadsPerThreadgroup: MTLSize(width: {ww}, height: {wh}, depth: {wd}))\n",
        gw = grid[0],
        gh = grid[1],
        gd = grid[2],
        ww = workgroup[0],
        wh = workgroup[1],
        wd = workgroup[2]
    ));
    host.push_str("enc.endEncoding()\ncmd.commit()\ncmd.waitUntilCompleted()\n\n");

    // Write each buffer's (possibly mutated) contents back to its file.
    for (i, path) in buf_paths.iter().enumerate() {
        host.push_str(&format!(
            "let out{i} = Data(bytes: buf{i}.contents(), count: buf{i}.length)\n\
             try! out{i}.write(to: URL(fileURLWithPath: \"{path}\"))\n"
        ));
    }

    host
}
336
337#[allow(clippy::needless_pass_by_value)]
338fn generate_cuda_host(
339    kernel_code: &str,
340    _buf_paths: &[String],
341    scalars: &[u32],
342    grid: [u32; 3],
343    workgroup: [u32; 3],
344    buffers: &[&mut DeviceBuffer],
345) -> String {
346    let mut src = String::from("#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n\n");
347    src.push_str(kernel_code);
348    src.push_str("\n\nint main(int argc, char** argv) {\n");
349
350    for (i, buf) in buffers.iter().enumerate() {
351        src.push_str(&format!("    float* d_buf{i};\n"));
352        src.push_str(&format!(
353            "    cudaMalloc(&d_buf{i}, {});\n",
354            buf.size_bytes()
355        ));
356        src.push_str(&format!(
357            "    FILE* f{i} = fopen(argv[{idx}], \"rb\");\n",
358            idx = i + 1
359        ));
360        src.push_str(&format!(
361            "    float* h{i} = (float*)malloc({});\n",
362            buf.size_bytes()
363        ));
364        src.push_str(&format!(
365            "    fread(h{i}, 1, {}, f{i});\n",
366            buf.size_bytes()
367        ));
368        src.push_str(&format!("    fclose(f{i});\n"));
369        src.push_str(&format!(
370            "    cudaMemcpy(d_buf{i}, h{i}, {}, cudaMemcpyHostToDevice);\n",
371            buf.size_bytes()
372        ));
373    }
374
375    let scalar_args: Vec<String> = scalars.iter().map(|s| format!("{s}")).collect();
376    let buf_args: Vec<String> = (0..buffers.len()).map(|i| format!("d_buf{i}")).collect();
377    let mut all_args = buf_args;
378    all_args.extend(scalar_args);
379
380    src.push_str(&format!(
381        "    dim3 grid({}, {}, {});\n",
382        grid[0], grid[1], grid[2]
383    ));
384    src.push_str(&format!(
385        "    dim3 block({}, {}, {});\n",
386        workgroup[0], workgroup[1], workgroup[2]
387    ));
388
389    src.push_str(&format!(
390        "    vector_add<<<grid, block>>>({});\n",
391        all_args.join(", ")
392    ));
393    src.push_str("    cudaDeviceSynchronize();\n\n");
394
395    for (i, buf) in buffers.iter().enumerate() {
396        src.push_str(&format!(
397            "    cudaMemcpy(h{i}, d_buf{i}, {}, cudaMemcpyDeviceToHost);\n",
398            buf.size_bytes()
399        ));
400        src.push_str(&format!(
401            "    FILE* o{i} = fopen(argv[{idx}], \"wb\");\n",
402            idx = i + 1
403        ));
404        src.push_str(&format!(
405            "    fwrite(h{i}, 1, {}, o{i});\n",
406            buf.size_bytes()
407        ));
408        src.push_str(&format!("    fclose(o{i});\n"));
409        src.push_str(&format!("    cudaFree(d_buf{i});\n"));
410        src.push_str(&format!("    free(h{i});\n"));
411    }
412
413    src.push_str("    return 0;\n}\n");
414    src
415}
416
417#[allow(clippy::needless_pass_by_value)]
418fn generate_hip_host(
419    kernel_code: &str,
420    _buf_paths: &[String],
421    scalars: &[u32],
422    grid: [u32; 3],
423    workgroup: [u32; 3],
424    buffers: &[&mut DeviceBuffer],
425) -> String {
426    let mut src =
427        String::from("#include <hip/hip_runtime.h>\n#include <cstdio>\n#include <cstdlib>\n\n");
428    src.push_str(kernel_code);
429    src.push_str("\n\nint main(int argc, char** argv) {\n");
430
431    for (i, buf) in buffers.iter().enumerate() {
432        src.push_str(&format!("    float* d_buf{i};\n"));
433        src.push_str(&format!(
434            "    hipMalloc(&d_buf{i}, {});\n",
435            buf.size_bytes()
436        ));
437        src.push_str(&format!(
438            "    FILE* f{i} = fopen(argv[{idx}], \"rb\");\n",
439            idx = i + 1
440        ));
441        src.push_str(&format!(
442            "    float* h{i} = (float*)malloc({});\n",
443            buf.size_bytes()
444        ));
445        src.push_str(&format!(
446            "    fread(h{i}, 1, {}, f{i});\n",
447            buf.size_bytes()
448        ));
449        src.push_str(&format!("    fclose(f{i});\n"));
450        src.push_str(&format!(
451            "    hipMemcpy(d_buf{i}, h{i}, {}, hipMemcpyHostToDevice);\n",
452            buf.size_bytes()
453        ));
454    }
455
456    let scalar_args: Vec<String> = scalars.iter().map(|s| format!("{s}")).collect();
457    let buf_args: Vec<String> = (0..buffers.len()).map(|i| format!("d_buf{i}")).collect();
458    let mut all_args = buf_args;
459    all_args.extend(scalar_args);
460
461    src.push_str(&format!(
462        "    dim3 grid({}, {}, {});\n",
463        grid[0], grid[1], grid[2]
464    ));
465    src.push_str(&format!(
466        "    dim3 block({}, {}, {});\n",
467        workgroup[0], workgroup[1], workgroup[2]
468    ));
469    src.push_str(&format!(
470        "    hipLaunchKernelGGL(vector_add, grid, block, 0, 0, {});\n",
471        all_args.join(", ")
472    ));
473    src.push_str("    hipDeviceSynchronize();\n\n");
474
475    for (i, buf) in buffers.iter().enumerate() {
476        src.push_str(&format!(
477            "    hipMemcpy(h{i}, d_buf{i}, {}, hipMemcpyDeviceToHost);\n",
478            buf.size_bytes()
479        ));
480        src.push_str(&format!(
481            "    FILE* o{i} = fopen(argv[{idx}], \"wb\");\n",
482            idx = i + 1
483        ));
484        src.push_str(&format!(
485            "    fwrite(h{i}, 1, {}, o{i});\n",
486            buf.size_bytes()
487        ));
488        src.push_str(&format!("    fclose(o{i});\n"));
489        src.push_str(&format!("    hipFree(d_buf{i});\n"));
490        src.push_str(&format!("    free(h{i});\n"));
491    }
492
493    src.push_str("    return 0;\n}\n");
494    src
495}
496
497#[allow(clippy::needless_pass_by_value)]
498fn generate_sycl_host(
499    kernel_code: &str,
500    _buf_paths: &[String],
501    scalars: &[u32],
502    grid: [u32; 3],
503    workgroup: [u32; 3],
504    buffers: &[&mut DeviceBuffer],
505) -> String {
506    let mut src =
507        String::from("#include <sycl/sycl.hpp>\n#include <cstdio>\n#include <cstdlib>\n\n");
508    src.push_str(kernel_code);
509    src.push_str("\n\nint main(int argc, char** argv) {\n");
510    src.push_str("    sycl::queue q;\n\n");
511
512    for (i, buf) in buffers.iter().enumerate() {
513        let count = buf.count;
514        src.push_str(&format!(
515            "    float* d_buf{i} = sycl::malloc_device<float>({count}, q);\n"
516        ));
517        src.push_str(&format!(
518            "    FILE* f{i} = fopen(argv[{idx}], \"rb\");\n",
519            idx = i + 1
520        ));
521        src.push_str(&format!(
522            "    float* h{i} = (float*)malloc({});\n",
523            buf.size_bytes()
524        ));
525        src.push_str(&format!(
526            "    fread(h{i}, 1, {}, f{i});\n",
527            buf.size_bytes()
528        ));
529        src.push_str(&format!("    fclose(f{i});\n"));
530        src.push_str(&format!(
531            "    q.memcpy(d_buf{i}, h{i}, {}).wait();\n",
532            buf.size_bytes()
533        ));
534    }
535
536    let total_threads = grid[0] * workgroup[0];
537    src.push_str(&format!(
538        "    q.parallel_for(sycl::range<1>({total_threads}), [=](sycl::id<1> idx) {{\n"
539    ));
540    src.push_str("        uint32_t gid = idx[0];\n");
541
542    let scalar_vals: Vec<String> = scalars.iter().map(|s| format!("{s}")).collect();
543    let buf_args: Vec<String> = (0..buffers.len()).map(|i| format!("d_buf{i}")).collect();
544    let mut all_args = buf_args;
545    all_args.extend(scalar_vals);
546
547    src.push_str(&format!("        kernel_func({});\n", all_args.join(", ")));
548    src.push_str("    }).wait();\n\n");
549
550    for (i, buf) in buffers.iter().enumerate() {
551        src.push_str(&format!(
552            "    q.memcpy(h{i}, d_buf{i}, {}).wait();\n",
553            buf.size_bytes()
554        ));
555        src.push_str(&format!(
556            "    FILE* o{i} = fopen(argv[{idx}], \"wb\");\n",
557            idx = i + 1
558        ));
559        src.push_str(&format!(
560            "    fwrite(h{i}, 1, {}, o{i});\n",
561            buf.size_bytes()
562        ));
563        src.push_str(&format!("    fclose(o{i});\n"));
564        src.push_str(&format!("    sycl::free(d_buf{i}, q);\n"));
565        src.push_str(&format!("    free(h{i});\n"));
566    }
567
568    src.push_str("    return 0;\n}\n");
569    src
570}