//! NDArray matrix multiplication example
//! (`06_ndarray_matrix_multiplication/06_ndarray_matrix_multiplication.rs`).

use apple_metal::{resource_options, MetalBuffer, MetalDevice};
use apple_mps::{data_type, NDArray, NDArrayDescriptor, NDArrayMatrixMultiplication};

/// Reinterprets a slice of `T` as its raw in-memory bytes (native endianness).
///
/// Intended for plain numeric element types such as `f32`. `T` must not contain
/// padding bytes: padding is uninitialized memory, and exposing it as `u8` is
/// undefined behavior.
fn as_bytes<T>(values: &[T]) -> &[u8] {
    // SAFETY: the pointer comes from a valid `&[T]`, and `size_of_val` gives exactly the
    // number of bytes that slice spans. `u8` has alignment 1. The returned slice borrows
    // `values`, so it cannot outlive the data. Soundness also requires `T` to have no
    // padding bytes (see the doc comment above).
    unsafe {
        core::slice::from_raw_parts(values.as_ptr().cast::<u8>(), core::mem::size_of_val(values))
    }
}

10fn buffer_with_f32_values_padded(device: &MetalDevice, values: &[f32], byte_len: usize) -> MetalBuffer {
11    let buffer = device
12        .new_buffer(
13            byte_len.max(core::mem::size_of_val(values)),
14            resource_options::STORAGE_MODE_SHARED,
15        )
16        .expect("buffer");
17    let _ = buffer.write_bytes(as_bytes(values));
18    buffer
19}
20
21fn read_f32_values(buffer: &MetalBuffer, len: usize) -> Vec<f32> {
22    let ptr = buffer.contents().expect("buffer contents").cast::<f32>();
23    unsafe { core::slice::from_raw_parts(ptr, len).to_vec() }
24}
25
26fn main() {
27    let device = MetalDevice::system_default().expect("no Metal device available");
28    let queue = device.new_command_queue().expect("command queue");
29
30    let descriptor = NDArrayDescriptor::with_dimension_sizes(data_type::FLOAT32, &[2, 2, 1, 1]).expect("descriptor");
31    let template = NDArray::new(&device, &descriptor).expect("template ndarray");
32    let byte_len = template.resource_size();
33    let rows = descriptor.length_of_dimension(1);
34    let row_stride_floats = byte_len / core::mem::size_of::<f32>() / rows;
35    let left_buffer = buffer_with_f32_values_padded(
36        &device,
37        &[1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0],
38        byte_len,
39    );
40    let right_buffer = buffer_with_f32_values_padded(
41        &device,
42        &[5.0, 6.0, 0.0, 0.0, 7.0, 8.0, 0.0, 0.0],
43        byte_len,
44    );
45    let destination_buffer = buffer_with_f32_values_padded(&device, &[0.0; 8], byte_len);
46
47    let left = NDArray::new_with_buffer(&left_buffer, 0, &descriptor).expect("left ndarray");
48    let right = NDArray::new_with_buffer(&right_buffer, 0, &descriptor).expect("right ndarray");
49    let destination = NDArray::new_with_buffer(&destination_buffer, 0, &descriptor).expect("destination ndarray");
50
51    let kernel = NDArrayMatrixMultiplication::new(&device, 2).expect("ndarray matmul");
52    kernel.set_alpha(1.0);
53    kernel.set_beta(0.0);
54
55    let command_buffer = queue.new_command_buffer().expect("command buffer");
56    kernel.encode_to_destination(&command_buffer, &[&left, &right], &destination);
57    command_buffer.commit();
58    command_buffer.wait_until_completed();
59
60    let padded_output = read_f32_values(&destination_buffer, row_stride_floats * rows);
61    let output = [
62        padded_output[0],
63        padded_output[1],
64        padded_output[row_stride_floats],
65        padded_output[row_stride_floats + 1],
66    ];
67    println!("{output:?}");
68}