Skip to main content

09_rnn_image_inference/
09_rnn_image_inference.rs

1use apple_metal::MetalDevice;
2use apple_mps::{
3    feature_channel_format, Image, ImageDescriptor, RnnImageInferenceLayer, RnnSingleGateDescriptor,
4};
5
6fn main() {
7    let device = MetalDevice::system_default().expect("no Metal device available");
8    let queue = device.new_command_queue().expect("command queue");
9
10    let single_gate = RnnSingleGateDescriptor::new(1, 1).expect("single gate descriptor");
11    single_gate.set_use_layer_input_unit_transform_mode(true);
12    let descriptor = single_gate.as_descriptor().expect("base descriptor");
13    let layer = RnnImageInferenceLayer::new(&device, &descriptor).expect("rnn layer");
14
15    let image_descriptor = ImageDescriptor::new(1, 1, 1, feature_channel_format::FLOAT32);
16    let src0 = Image::new(&device, image_descriptor).expect("src0");
17    let src1 = Image::new(&device, image_descriptor).expect("src1");
18    let dst0 = Image::new(&device, image_descriptor).expect("dst0");
19    let dst1 = Image::new(&device, image_descriptor).expect("dst1");
20    src0.write_f32(&[0.25]).expect("write src0");
21    src1.write_f32(&[0.75]).expect("write src1");
22
23    let command_buffer = queue.new_command_buffer().expect("command buffer");
24    let recurrent_state = layer
25        .encode_sequence(&command_buffer, &[&src0, &src1], &[&dst0, &dst1], None)
26        .expect("recurrent state");
27    command_buffer.commit();
28    command_buffer.wait_until_completed();
29
30    let recurrent_output = recurrent_state
31        .recurrent_output_image_for_layer_index(0)
32        .expect("recurrent output image");
33    println!(
34        "{:?} {}x{}x{}",
35        dst1.read_f32().expect("dst1 output"),
36        recurrent_output.width(),
37        recurrent_output.height(),
38        recurrent_output.feature_channels()
39    );
40}