04_compute_shader/
04_compute_shader.rs1#![allow(clippy::cast_precision_loss, clippy::float_cmp)]
6
7use apple_metal::{resource_options, MetalDevice};
8
/// Metal Shading Language source that `main` hands to `new_library_with_source`.
///
/// It defines a single kernel, `mul2`, taking a `device float *` bound at
/// `buffer(0)`. Each thread overwrites the element at its grid position with
/// twice its value. The kernel does no bounds check on that index.
9const KERNEL_SRC: &str = "
10#include <metal_stdlib>
11using namespace metal;
12
13kernel void mul2(device float *data [[buffer(0)]],
14 uint i [[thread_position_in_grid]]) {
15 data[i] = data[i] * 2.0;
16}
17";
18
/// Number of `f32` elements in the demo buffer. `main` uses it to size the
/// allocation, as the length of the host-side slices, and as the count passed
/// to `dispatch_compute_1d`.
19const N: usize = 16;
20
21fn main() {
22 let device = MetalDevice::system_default().expect("MTLCreateSystemDefaultDevice");
23 println!("Device unified={}", device.has_unified_memory());
24
25 let lib = device
26 .new_library_with_source(KERNEL_SRC)
27 .expect("compile MSL source");
28 println!("✅ Compiled library {:p}", lib.as_ptr());
29
30 let func = lib.new_function("mul2").expect("locate function 'mul2'");
31 println!("✅ Found function mul2 {:p}", func.as_ptr());
32
33 let pso = device
34 .new_compute_pipeline_state(&func)
35 .expect("build compute pipeline state");
36 println!("✅ Compute pipeline state {:p}", pso.as_ptr());
37
38 let byte_len = N * core::mem::size_of::<f32>();
39 let buffer = device
40 .new_buffer(byte_len, resource_options::STORAGE_MODE_SHARED)
41 .expect("allocate buffer");
42
43 let slice: &mut [f32] = unsafe {
44 core::slice::from_raw_parts_mut(
45 buffer.contents().expect("buffer.contents").cast::<f32>(),
46 N,
47 )
48 };
49 for (i, x) in slice.iter_mut().enumerate() {
50 *x = i as f32;
51 }
52 println!("Input : {slice:?}");
53
54 let queue = device.new_command_queue().expect("MTLCommandQueue");
55 let cb = queue.new_command_buffer().expect("MTLCommandBuffer");
56 let ok = cb.dispatch_compute_1d(&pso, &[&buffer], N, 1);
57 assert!(ok, "dispatch_compute_1d failed");
58 cb.commit();
59 cb.wait_until_completed();
60
61 let slice: &[f32] = unsafe {
62 core::slice::from_raw_parts(buffer.contents().expect("buffer.contents").cast::<f32>(), N)
63 };
64 println!("Output: {slice:?}");
65
66 for (i, &v) in slice.iter().enumerate() {
67 let expected = (i as f32) * 2.0;
68 assert_eq!(v, expected, "element {i} expected {expected} got {v}");
69 }
70 println!("✅ All {N} elements correctly doubled by the GPU kernel");
71}