09_rnn_image_inference/
09_rnn_image_inference.rs1use apple_metal::MetalDevice;
2use apple_mps::{feature_channel_format, Image, ImageDescriptor, RnnImageInferenceLayer, RnnSingleGateDescriptor};
3
4fn main() {
5 let device = MetalDevice::system_default().expect("no Metal device available");
6 let queue = device.new_command_queue().expect("command queue");
7
8 let single_gate = RnnSingleGateDescriptor::new(1, 1).expect("single gate descriptor");
9 single_gate.set_use_layer_input_unit_transform_mode(true);
10 let descriptor = single_gate.as_descriptor().expect("base descriptor");
11 let layer = RnnImageInferenceLayer::new(&device, &descriptor).expect("rnn layer");
12
13 let image_descriptor = ImageDescriptor::new(1, 1, 1, feature_channel_format::FLOAT32);
14 let src0 = Image::new(&device, image_descriptor).expect("src0");
15 let src1 = Image::new(&device, image_descriptor).expect("src1");
16 let dst0 = Image::new(&device, image_descriptor).expect("dst0");
17 let dst1 = Image::new(&device, image_descriptor).expect("dst1");
18 src0.write_f32(&[0.25]).expect("write src0");
19 src1.write_f32(&[0.75]).expect("write src1");
20
21 let command_buffer = queue.new_command_buffer().expect("command buffer");
22 let recurrent_state = layer
23 .encode_sequence(&command_buffer, &[&src0, &src1], &[&dst0, &dst1], None)
24 .expect("recurrent state");
25 command_buffer.commit();
26 command_buffer.wait_until_completed();
27
28 let recurrent_output = recurrent_state
29 .recurrent_output_image_for_layer_index(0)
30 .expect("recurrent output image");
31 println!(
32 "{:?} {}x{}x{}",
33 dst1.read_f32().expect("dst1 output"),
34 recurrent_output.width(),
35 recurrent_output.height(),
36 recurrent_output.feature_channels()
37 );
38}