Skip to main content

09_rnn_image_inference/
09_rnn_image_inference.rs

1use apple_metal::MetalDevice;
2use apple_mps::{feature_channel_format, Image, ImageDescriptor, RnnImageInferenceLayer, RnnSingleGateDescriptor};
3
4fn main() {
5    let device = MetalDevice::system_default().expect("no Metal device available");
6    let queue = device.new_command_queue().expect("command queue");
7
8    let single_gate = RnnSingleGateDescriptor::new(1, 1).expect("single gate descriptor");
9    single_gate.set_use_layer_input_unit_transform_mode(true);
10    let descriptor = single_gate.as_descriptor().expect("base descriptor");
11    let layer = RnnImageInferenceLayer::new(&device, &descriptor).expect("rnn layer");
12
13    let image_descriptor = ImageDescriptor::new(1, 1, 1, feature_channel_format::FLOAT32);
14    let src0 = Image::new(&device, image_descriptor).expect("src0");
15    let src1 = Image::new(&device, image_descriptor).expect("src1");
16    let dst0 = Image::new(&device, image_descriptor).expect("dst0");
17    let dst1 = Image::new(&device, image_descriptor).expect("dst1");
18    src0.write_f32(&[0.25]).expect("write src0");
19    src1.write_f32(&[0.75]).expect("write src1");
20
21    let command_buffer = queue.new_command_buffer().expect("command buffer");
22    let recurrent_state = layer
23        .encode_sequence(&command_buffer, &[&src0, &src1], &[&dst0, &dst1], None)
24        .expect("recurrent state");
25    command_buffer.commit();
26    command_buffer.wait_until_completed();
27
28    let recurrent_output = recurrent_state
29        .recurrent_output_image_for_layer_index(0)
30        .expect("recurrent output image");
31    println!(
32        "{:?} {}x{}x{}",
33        dst1.read_f32().expect("dst1 output"),
34        recurrent_output.width(),
35        recurrent_output.height(),
36        recurrent_output.feature_channels()
37    );
38}