// 09_rnn_image_inference/09_rnn_image_inference.rs
use apple_metal::MetalDevice;
use apple_mps::{
    feature_channel_format, Image, ImageDescriptor, RnnImageInferenceLayer, RnnSingleGateDescriptor,
};

6fn main() {
7 let device = MetalDevice::system_default().expect("no Metal device available");
8 let queue = device.new_command_queue().expect("command queue");
9
10 let single_gate = RnnSingleGateDescriptor::new(1, 1).expect("single gate descriptor");
11 single_gate.set_use_layer_input_unit_transform_mode(true);
12 let descriptor = single_gate.as_descriptor().expect("base descriptor");
13 let layer = RnnImageInferenceLayer::new(&device, &descriptor).expect("rnn layer");
14
15 let image_descriptor = ImageDescriptor::new(1, 1, 1, feature_channel_format::FLOAT32);
16 let src0 = Image::new(&device, image_descriptor).expect("src0");
17 let src1 = Image::new(&device, image_descriptor).expect("src1");
18 let dst0 = Image::new(&device, image_descriptor).expect("dst0");
19 let dst1 = Image::new(&device, image_descriptor).expect("dst1");
20 src0.write_f32(&[0.25]).expect("write src0");
21 src1.write_f32(&[0.75]).expect("write src1");
22
23 let command_buffer = queue.new_command_buffer().expect("command buffer");
24 let recurrent_state = layer
25 .encode_sequence(&command_buffer, &[&src0, &src1], &[&dst0, &dst1], None)
26 .expect("recurrent state");
27 command_buffer.commit();
28 command_buffer.wait_until_completed();
29
30 let recurrent_output = recurrent_state
31 .recurrent_output_image_for_layer_index(0)
32 .expect("recurrent output image");
33 println!(
34 "{:?} {}x{}x{}",
35 dst1.read_f32().expect("dst1 output"),
36 recurrent_output.width(),
37 recurrent_output.height(),
38 recurrent_output.feature_channels()
39 );
40}