1#![allow(clippy::format_push_string)]
12
13use crate::device::GpuVendor;
14use crate::error::RuntimeError;
15use crate::memory::DeviceBuffer;
16use std::fs;
17use std::path::Path;
18use std::process::Command;
19
20#[allow(clippy::too_many_arguments)]
27pub fn launch_kernel(
28 vendor_code: &str,
29 wbin: &[u8],
30 vendor: GpuVendor,
31 buffers: &mut [&mut DeviceBuffer],
32 scalars: &[u32],
33 grid: [u32; 3],
34 workgroup: [u32; 3],
35) -> Result<(), RuntimeError> {
36 match vendor {
37 GpuVendor::Apple => launch_metal(vendor_code, buffers, scalars, grid, workgroup),
38 GpuVendor::Nvidia => launch_cuda(vendor_code, buffers, scalars, grid, workgroup),
39 GpuVendor::Amd => launch_hip(vendor_code, buffers, scalars, grid, workgroup),
40 GpuVendor::Intel => launch_sycl(vendor_code, buffers, scalars, grid, workgroup),
41 GpuVendor::Emulator => launch_emulator(wbin, buffers, scalars, grid, workgroup),
42 }
43}
44
45fn launch_emulator(
46 wbin: &[u8],
47 buffers: &mut [&mut DeviceBuffer],
48 scalars: &[u32],
49 grid: [u32; 3],
50 workgroup: [u32; 3],
51) -> Result<(), RuntimeError> {
52 let total_buffer_bytes: usize = buffers.iter().map(|b| b.size_bytes()).sum();
53 let device_mem_size = (total_buffer_bytes + 4096).max(1024 * 1024);
54
55 let mut config = wave_emu::EmulatorConfig {
56 grid_dim: grid,
57 workgroup_dim: workgroup,
58 device_memory_size: device_mem_size,
59 ..wave_emu::EmulatorConfig::default()
60 };
61
62 let mut initial_regs: Vec<(u8, u32)> = Vec::new();
63 let mut offset: u32 = 0;
64 for (i, buf) in buffers.iter().enumerate() {
65 #[allow(clippy::cast_possible_truncation)]
66 let reg_idx = i as u8;
67 initial_regs.push((reg_idx, offset));
68 #[allow(clippy::cast_possible_truncation)]
69 let size = buf.size_bytes() as u32;
70 offset += size;
71 }
72 for (i, &scalar) in scalars.iter().enumerate() {
73 #[allow(clippy::cast_possible_truncation)]
74 let reg_idx = (buffers.len() + i) as u8;
75 initial_regs.push((reg_idx, scalar));
76 }
77 config.initial_registers = initial_regs;
78
79 let mut emu = wave_emu::Emulator::new(config);
80 emu.load_binary(wbin)?;
81
82 let mut mem_offset: u64 = 0;
83 for buf in buffers.iter() {
84 emu.load_device_memory(mem_offset, &buf.data)?;
85 #[allow(clippy::cast_possible_truncation)]
86 let size = buf.size_bytes() as u64;
87 mem_offset += size;
88 }
89
90 emu.run()?;
91
92 let mut read_offset: u64 = 0;
93 for buf in buffers.iter_mut() {
94 let size = buf.size_bytes();
95 let result = emu.read_device_memory(read_offset, size)?;
96 buf.data = result;
97 read_offset += size as u64;
98 }
99
100 Ok(())
101}
102
103fn write_buffer_files(
104 dir: &Path,
105 buffers: &[&mut DeviceBuffer],
106) -> Result<Vec<String>, RuntimeError> {
107 let mut paths = Vec::new();
108 for (i, buf) in buffers.iter().enumerate() {
109 let path = dir.join(format!("buf_{i}.bin"));
110 fs::write(&path, &buf.data)?;
111 paths.push(
112 path.to_str()
113 .ok_or_else(|| RuntimeError::Io("invalid path".into()))?
114 .to_string(),
115 );
116 }
117 Ok(paths)
118}
119
120fn read_buffer_files(dir: &Path, buffers: &mut [&mut DeviceBuffer]) -> Result<(), RuntimeError> {
121 for (i, buf) in buffers.iter_mut().enumerate() {
122 let path = dir.join(format!("buf_{i}.bin"));
123 buf.data = fs::read(&path)?;
124 }
125 Ok(())
126}
127
128fn launch_metal(
129 vendor_code: &str,
130 buffers: &mut [&mut DeviceBuffer],
131 scalars: &[u32],
132 grid: [u32; 3],
133 workgroup: [u32; 3],
134) -> Result<(), RuntimeError> {
135 let dir = tempfile::tempdir().map_err(|e| RuntimeError::Io(e.to_string()))?;
136 let metal_path = dir.path().join("kernel.metal");
137 fs::write(&metal_path, vendor_code)?;
138
139 let buf_paths = write_buffer_files(dir.path(), buffers)?;
140
141 let host_src = generate_metal_host(&buf_paths, scalars, grid, workgroup);
142 let host_path = dir.path().join("host.swift");
143 fs::write(&host_path, &host_src)?;
144
145 let lib_path = dir.path().join("kernel.metallib");
146 let status = Command::new("xcrun")
147 .args(["-sdk", "macosx", "metal", "-o"])
148 .arg(&lib_path)
149 .arg(&metal_path)
150 .status()?;
151 if !status.success() {
152 return Err(RuntimeError::Launch(
153 "Metal shader compilation failed".into(),
154 ));
155 }
156
157 let exe_path = dir.path().join("host");
158 let status = Command::new("swiftc")
159 .arg("-o")
160 .arg(&exe_path)
161 .arg(&host_path)
162 .arg("-framework")
163 .arg("Metal")
164 .arg("-framework")
165 .arg("Foundation")
166 .status()?;
167 if !status.success() {
168 return Err(RuntimeError::Launch("Swift host compilation failed".into()));
169 }
170
171 let status = Command::new(&exe_path).arg(&lib_path).status()?;
172 if !status.success() {
173 return Err(RuntimeError::Launch("Metal kernel execution failed".into()));
174 }
175
176 read_buffer_files(dir.path(), buffers)?;
177 Ok(())
178}
179
180fn launch_cuda(
181 vendor_code: &str,
182 buffers: &mut [&mut DeviceBuffer],
183 scalars: &[u32],
184 grid: [u32; 3],
185 workgroup: [u32; 3],
186) -> Result<(), RuntimeError> {
187 let dir = tempfile::tempdir().map_err(|e| RuntimeError::Io(e.to_string()))?;
188
189 let host_src = generate_cuda_host(vendor_code, &[], scalars, grid, workgroup, buffers);
190 let cu_path = dir.path().join("kernel.cu");
191 fs::write(&cu_path, &host_src)?;
192
193 let buf_paths = write_buffer_files(dir.path(), buffers)?;
194
195 let exe_path = dir.path().join("kernel");
196 let status = Command::new("nvcc")
197 .arg("-o")
198 .arg(&exe_path)
199 .arg(&cu_path)
200 .status()?;
201 if !status.success() {
202 return Err(RuntimeError::Launch("CUDA compilation failed".into()));
203 }
204
205 let status = Command::new(&exe_path).args(&buf_paths).status()?;
206 if !status.success() {
207 return Err(RuntimeError::Launch("CUDA kernel execution failed".into()));
208 }
209
210 read_buffer_files(dir.path(), buffers)?;
211 Ok(())
212}
213
214fn launch_hip(
215 vendor_code: &str,
216 buffers: &mut [&mut DeviceBuffer],
217 scalars: &[u32],
218 grid: [u32; 3],
219 workgroup: [u32; 3],
220) -> Result<(), RuntimeError> {
221 let dir = tempfile::tempdir().map_err(|e| RuntimeError::Io(e.to_string()))?;
222
223 let host_src = generate_hip_host(vendor_code, &[], scalars, grid, workgroup, buffers);
224 let hip_path = dir.path().join("kernel.hip");
225 fs::write(&hip_path, &host_src)?;
226
227 let buf_paths = write_buffer_files(dir.path(), buffers)?;
228
229 let exe_path = dir.path().join("kernel");
230 let status = Command::new("hipcc")
231 .arg("-o")
232 .arg(&exe_path)
233 .arg(&hip_path)
234 .status()?;
235 if !status.success() {
236 return Err(RuntimeError::Launch("HIP compilation failed".into()));
237 }
238
239 let status = Command::new(&exe_path).args(&buf_paths).status()?;
240 if !status.success() {
241 return Err(RuntimeError::Launch("HIP kernel execution failed".into()));
242 }
243
244 read_buffer_files(dir.path(), buffers)?;
245 Ok(())
246}
247
248fn launch_sycl(
249 vendor_code: &str,
250 buffers: &mut [&mut DeviceBuffer],
251 scalars: &[u32],
252 grid: [u32; 3],
253 workgroup: [u32; 3],
254) -> Result<(), RuntimeError> {
255 let dir = tempfile::tempdir().map_err(|e| RuntimeError::Io(e.to_string()))?;
256
257 let host_src = generate_sycl_host(vendor_code, &[], scalars, grid, workgroup, buffers);
258 let cpp_path = dir.path().join("kernel.cpp");
259 fs::write(&cpp_path, &host_src)?;
260
261 let buf_paths = write_buffer_files(dir.path(), buffers)?;
262
263 let exe_path = dir.path().join("kernel");
264 let status = Command::new("icpx")
265 .arg("-fsycl")
266 .arg("-o")
267 .arg(&exe_path)
268 .arg(&cpp_path)
269 .status()?;
270 if !status.success() {
271 return Err(RuntimeError::Launch("SYCL compilation failed".into()));
272 }
273
274 let status = Command::new(&exe_path).args(&buf_paths).status()?;
275 if !status.success() {
276 return Err(RuntimeError::Launch("SYCL kernel execution failed".into()));
277 }
278
279 read_buffer_files(dir.path(), buffers)?;
280 Ok(())
281}
282
/// Generates the Swift host program used by the Metal backend.
///
/// The program loads the metallib named by its first command-line argument and
/// uses the first function in that library as the compute kernel. It then:
/// * reads buffer `i` from `buf_paths[i]` and binds it at index `i`;
/// * binds scalar `j` with `setBytes` at index `buf_paths.len() + j`;
/// * dispatches `grid` threadgroups of `workgroup` threads each;
/// * exits with status 1 if the command buffer did not complete, and otherwise
///   writes every buffer back to the path it was read from.
///
/// Paths are escaped before being embedded in Swift string literals, so a path
/// containing `"` or `\` cannot break the generated source.
fn generate_metal_host(
    buf_paths: &[String],
    scalars: &[u32],
    grid: [u32; 3],
    workgroup: [u32; 3],
) -> String {
    let mut src = String::from("import Metal\nimport Foundation\n\n");
    src.push_str("let device = MTLCreateSystemDefaultDevice()!\n");
    src.push_str("let lib = try! device.makeLibrary(filepath: CommandLine.arguments[1])\n");
    src.push_str(
        "let function = lib.functionNames.first.flatMap { lib.makeFunction(name: $0) }!\n",
    );
    src.push_str("let pipeline = try! device.makeComputePipelineState(function: function)\n");
    src.push_str("let queue = device.makeCommandQueue()!\n");
    src.push_str("let cmd = queue.makeCommandBuffer()!\n");
    src.push_str("let enc = cmd.makeComputeCommandEncoder()!\n");
    src.push_str("enc.setComputePipelineState(pipeline)\n\n");

    for (i, path) in buf_paths.iter().enumerate() {
        // Every escape `escape_default` emits (`\"`, `\\`, `\n`, `\'`, `\u{..}`)
        // is also a valid Swift string-literal escape.
        let path = path.escape_default();
        src.push_str(&format!(
            "let data{i} = try! Data(contentsOf: URL(fileURLWithPath: \"{path}\"))\n"
        ));
        src.push_str(&format!(
            "let buf{i} = device.makeBuffer(bytes: (data{i} as NSData).bytes, length: data{i}.count, options: .storageModeShared)!\n"
        ));
        src.push_str(&format!("enc.setBuffer(buf{i}, offset: 0, index: {i})\n"));
    }

    for (i, &s) in scalars.iter().enumerate() {
        let idx = buf_paths.len() + i;
        src.push_str(&format!("var scalar{i}: UInt32 = {s}\n"));
        src.push_str(&format!(
            "enc.setBytes(&scalar{i}, length: 4, index: {idx})\n"
        ));
    }

    src.push_str(&format!(
        "\nenc.dispatchThreadgroups(MTLSize(width: {}, height: {}, depth: {}), threadsPerThreadgroup: MTLSize(width: {}, height: {}, depth: {}))\n",
        grid[0], grid[1], grid[2], workgroup[0], workgroup[1], workgroup[2]
    ));
    src.push_str("enc.endEncoding()\ncmd.commit()\ncmd.waitUntilCompleted()\n\n");

    // `waitUntilCompleted` also returns for a failed command buffer. Without this
    // check the host would write the (possibly unmodified) buffers back and exit 0,
    // so the caller would report success.
    src.push_str(concat!(
        "if cmd.status != .completed {\n",
        "    fputs(\"Metal command buffer failed: \\(String(describing: cmd.error))\\n\", stderr)\n",
        "    exit(1)\n",
        "}\n\n",
    ));

    for (i, path) in buf_paths.iter().enumerate() {
        let path = path.escape_default();
        src.push_str(&format!(
            "let out{i} = Data(bytes: buf{i}.contents(), count: buf{i}.length)\n"
        ));
        src.push_str(&format!(
            "try! out{i}.write(to: URL(fileURLWithPath: \"{path}\"))\n"
        ));
    }

    src
}
336
337#[allow(clippy::needless_pass_by_value)]
338fn generate_cuda_host(
339 kernel_code: &str,
340 _buf_paths: &[String],
341 scalars: &[u32],
342 grid: [u32; 3],
343 workgroup: [u32; 3],
344 buffers: &[&mut DeviceBuffer],
345) -> String {
346 let mut src = String::from("#include <cstdio>\n#include <cstdlib>\n#include <cstring>\n\n");
347 src.push_str(kernel_code);
348 src.push_str("\n\nint main(int argc, char** argv) {\n");
349
350 for (i, buf) in buffers.iter().enumerate() {
351 src.push_str(&format!(" float* d_buf{i};\n"));
352 src.push_str(&format!(
353 " cudaMalloc(&d_buf{i}, {});\n",
354 buf.size_bytes()
355 ));
356 src.push_str(&format!(
357 " FILE* f{i} = fopen(argv[{idx}], \"rb\");\n",
358 idx = i + 1
359 ));
360 src.push_str(&format!(
361 " float* h{i} = (float*)malloc({});\n",
362 buf.size_bytes()
363 ));
364 src.push_str(&format!(
365 " fread(h{i}, 1, {}, f{i});\n",
366 buf.size_bytes()
367 ));
368 src.push_str(&format!(" fclose(f{i});\n"));
369 src.push_str(&format!(
370 " cudaMemcpy(d_buf{i}, h{i}, {}, cudaMemcpyHostToDevice);\n",
371 buf.size_bytes()
372 ));
373 }
374
375 let scalar_args: Vec<String> = scalars.iter().map(|s| format!("{s}")).collect();
376 let buf_args: Vec<String> = (0..buffers.len()).map(|i| format!("d_buf{i}")).collect();
377 let mut all_args = buf_args;
378 all_args.extend(scalar_args);
379
380 src.push_str(&format!(
381 " dim3 grid({}, {}, {});\n",
382 grid[0], grid[1], grid[2]
383 ));
384 src.push_str(&format!(
385 " dim3 block({}, {}, {});\n",
386 workgroup[0], workgroup[1], workgroup[2]
387 ));
388
389 src.push_str(&format!(
390 " vector_add<<<grid, block>>>({});\n",
391 all_args.join(", ")
392 ));
393 src.push_str(" cudaDeviceSynchronize();\n\n");
394
395 for (i, buf) in buffers.iter().enumerate() {
396 src.push_str(&format!(
397 " cudaMemcpy(h{i}, d_buf{i}, {}, cudaMemcpyDeviceToHost);\n",
398 buf.size_bytes()
399 ));
400 src.push_str(&format!(
401 " FILE* o{i} = fopen(argv[{idx}], \"wb\");\n",
402 idx = i + 1
403 ));
404 src.push_str(&format!(
405 " fwrite(h{i}, 1, {}, o{i});\n",
406 buf.size_bytes()
407 ));
408 src.push_str(&format!(" fclose(o{i});\n"));
409 src.push_str(&format!(" cudaFree(d_buf{i});\n"));
410 src.push_str(&format!(" free(h{i});\n"));
411 }
412
413 src.push_str(" return 0;\n}\n");
414 src
415}
416
417#[allow(clippy::needless_pass_by_value)]
418fn generate_hip_host(
419 kernel_code: &str,
420 _buf_paths: &[String],
421 scalars: &[u32],
422 grid: [u32; 3],
423 workgroup: [u32; 3],
424 buffers: &[&mut DeviceBuffer],
425) -> String {
426 let mut src =
427 String::from("#include <hip/hip_runtime.h>\n#include <cstdio>\n#include <cstdlib>\n\n");
428 src.push_str(kernel_code);
429 src.push_str("\n\nint main(int argc, char** argv) {\n");
430
431 for (i, buf) in buffers.iter().enumerate() {
432 src.push_str(&format!(" float* d_buf{i};\n"));
433 src.push_str(&format!(
434 " hipMalloc(&d_buf{i}, {});\n",
435 buf.size_bytes()
436 ));
437 src.push_str(&format!(
438 " FILE* f{i} = fopen(argv[{idx}], \"rb\");\n",
439 idx = i + 1
440 ));
441 src.push_str(&format!(
442 " float* h{i} = (float*)malloc({});\n",
443 buf.size_bytes()
444 ));
445 src.push_str(&format!(
446 " fread(h{i}, 1, {}, f{i});\n",
447 buf.size_bytes()
448 ));
449 src.push_str(&format!(" fclose(f{i});\n"));
450 src.push_str(&format!(
451 " hipMemcpy(d_buf{i}, h{i}, {}, hipMemcpyHostToDevice);\n",
452 buf.size_bytes()
453 ));
454 }
455
456 let scalar_args: Vec<String> = scalars.iter().map(|s| format!("{s}")).collect();
457 let buf_args: Vec<String> = (0..buffers.len()).map(|i| format!("d_buf{i}")).collect();
458 let mut all_args = buf_args;
459 all_args.extend(scalar_args);
460
461 src.push_str(&format!(
462 " dim3 grid({}, {}, {});\n",
463 grid[0], grid[1], grid[2]
464 ));
465 src.push_str(&format!(
466 " dim3 block({}, {}, {});\n",
467 workgroup[0], workgroup[1], workgroup[2]
468 ));
469 src.push_str(&format!(
470 " hipLaunchKernelGGL(vector_add, grid, block, 0, 0, {});\n",
471 all_args.join(", ")
472 ));
473 src.push_str(" hipDeviceSynchronize();\n\n");
474
475 for (i, buf) in buffers.iter().enumerate() {
476 src.push_str(&format!(
477 " hipMemcpy(h{i}, d_buf{i}, {}, hipMemcpyDeviceToHost);\n",
478 buf.size_bytes()
479 ));
480 src.push_str(&format!(
481 " FILE* o{i} = fopen(argv[{idx}], \"wb\");\n",
482 idx = i + 1
483 ));
484 src.push_str(&format!(
485 " fwrite(h{i}, 1, {}, o{i});\n",
486 buf.size_bytes()
487 ));
488 src.push_str(&format!(" fclose(o{i});\n"));
489 src.push_str(&format!(" hipFree(d_buf{i});\n"));
490 src.push_str(&format!(" free(h{i});\n"));
491 }
492
493 src.push_str(" return 0;\n}\n");
494 src
495}
496
497#[allow(clippy::needless_pass_by_value)]
498fn generate_sycl_host(
499 kernel_code: &str,
500 _buf_paths: &[String],
501 scalars: &[u32],
502 grid: [u32; 3],
503 workgroup: [u32; 3],
504 buffers: &[&mut DeviceBuffer],
505) -> String {
506 let mut src =
507 String::from("#include <sycl/sycl.hpp>\n#include <cstdio>\n#include <cstdlib>\n\n");
508 src.push_str(kernel_code);
509 src.push_str("\n\nint main(int argc, char** argv) {\n");
510 src.push_str(" sycl::queue q;\n\n");
511
512 for (i, buf) in buffers.iter().enumerate() {
513 let count = buf.count;
514 src.push_str(&format!(
515 " float* d_buf{i} = sycl::malloc_device<float>({count}, q);\n"
516 ));
517 src.push_str(&format!(
518 " FILE* f{i} = fopen(argv[{idx}], \"rb\");\n",
519 idx = i + 1
520 ));
521 src.push_str(&format!(
522 " float* h{i} = (float*)malloc({});\n",
523 buf.size_bytes()
524 ));
525 src.push_str(&format!(
526 " fread(h{i}, 1, {}, f{i});\n",
527 buf.size_bytes()
528 ));
529 src.push_str(&format!(" fclose(f{i});\n"));
530 src.push_str(&format!(
531 " q.memcpy(d_buf{i}, h{i}, {}).wait();\n",
532 buf.size_bytes()
533 ));
534 }
535
536 let total_threads = grid[0] * workgroup[0];
537 src.push_str(&format!(
538 " q.parallel_for(sycl::range<1>({total_threads}), [=](sycl::id<1> idx) {{\n"
539 ));
540 src.push_str(" uint32_t gid = idx[0];\n");
541
542 let scalar_vals: Vec<String> = scalars.iter().map(|s| format!("{s}")).collect();
543 let buf_args: Vec<String> = (0..buffers.len()).map(|i| format!("d_buf{i}")).collect();
544 let mut all_args = buf_args;
545 all_args.extend(scalar_vals);
546
547 src.push_str(&format!(" kernel_func({});\n", all_args.join(", ")));
548 src.push_str(" }).wait();\n\n");
549
550 for (i, buf) in buffers.iter().enumerate() {
551 src.push_str(&format!(
552 " q.memcpy(h{i}, d_buf{i}, {}).wait();\n",
553 buf.size_bytes()
554 ));
555 src.push_str(&format!(
556 " FILE* o{i} = fopen(argv[{idx}], \"wb\");\n",
557 idx = i + 1
558 ));
559 src.push_str(&format!(
560 " fwrite(h{i}, 1, {}, o{i});\n",
561 buf.size_bytes()
562 ));
563 src.push_str(&format!(" fclose(o{i});\n"));
564 src.push_str(&format!(" sycl::free(d_buf{i}, q);\n"));
565 src.push_str(&format!(" free(h{i});\n"));
566 }
567
568 src.push_str(" return 0;\n}\n");
569 src
570}