Skip to main content

04_compute_shader/
04_compute_shader.rs

1//! Smoke test for the v0.5 compute pipeline: compiles a trivial
2//! "multiply by 2" Metal kernel, dispatches it on a shared buffer
3//! of 16 floats, and verifies every element doubled.
4
5#![allow(clippy::cast_precision_loss, clippy::float_cmp)]
6
7use apple_metal::{resource_options, MetalDevice};
8
/// Metal Shading Language source for the `mul2` compute kernel.
///
/// Each thread reads the float at its `thread_position_in_grid` index from
/// the buffer bound at index 0, multiplies it by 2.0, and writes it back in
/// place. The kernel performs no bounds check on `i`, so the dispatched grid
/// must not be larger than the number of floats in the bound buffer.
const KERNEL_SRC: &str = "
#include <metal_stdlib>
using namespace metal;

kernel void mul2(device float *data [[buffer(0)]],
                 uint i [[thread_position_in_grid]]) {
    data[i] = data[i] * 2.0;
}
";
18
/// Number of `f32` elements in the test buffer.
///
/// `main` uses it to size the allocation (`N * size_of::<f32>()` bytes), as
/// the length of both host-side slices, and as the third argument to
/// `dispatch_compute_1d` (presumably the 1-D grid width — confirm against
/// the `apple_metal` API).
const N: usize = 16;
20
21fn main() {
22    let device = MetalDevice::system_default().expect("MTLCreateSystemDefaultDevice");
23    println!("Device unified={}", device.has_unified_memory());
24
25    let lib = device
26        .new_library_with_source(KERNEL_SRC)
27        .expect("compile MSL source");
28    println!("✅ Compiled library {:p}", lib.as_ptr());
29
30    let func = lib.new_function("mul2").expect("locate function 'mul2'");
31    println!("✅ Found function mul2 {:p}", func.as_ptr());
32
33    let pso = device
34        .new_compute_pipeline_state(&func)
35        .expect("build compute pipeline state");
36    println!("✅ Compute pipeline state {:p}", pso.as_ptr());
37
38    let byte_len = N * core::mem::size_of::<f32>();
39    let buffer = device
40        .new_buffer(byte_len, resource_options::STORAGE_MODE_SHARED)
41        .expect("allocate buffer");
42
43    let slice: &mut [f32] = unsafe {
44        core::slice::from_raw_parts_mut(
45            buffer.contents().expect("buffer.contents").cast::<f32>(),
46            N,
47        )
48    };
49    for (i, x) in slice.iter_mut().enumerate() {
50        *x = i as f32;
51    }
52    println!("Input : {slice:?}");
53
54    let queue = device.new_command_queue().expect("MTLCommandQueue");
55    let cb = queue.new_command_buffer().expect("MTLCommandBuffer");
56    let ok = cb.dispatch_compute_1d(&pso, &[&buffer], N, 1);
57    assert!(ok, "dispatch_compute_1d failed");
58    cb.commit();
59    cb.wait_until_completed();
60
61    let slice: &[f32] = unsafe {
62        core::slice::from_raw_parts(buffer.contents().expect("buffer.contents").cast::<f32>(), N)
63    };
64    println!("Output: {slice:?}");
65
66    for (i, &v) in slice.iter().enumerate() {
67        let expected = (i as f32) * 2.0;
68        assert_eq!(v, expected, "element {i} expected {expected} got {v}");
69    }
70    println!("✅ All {N} elements correctly doubled by the GPU kernel");
71}