Skip to main content

05_nn_graph_relu/
05_nn_graph_relu.rs

1use apple_metal::MetalDevice;
2use apple_mps::{
3    feature_channel_format, rnn_sequence_direction, CnnConvolutionDescriptor, CnnNeuronReluNode,
4    CnnPoolingMaxNode, CnnSoftMaxNode, CnnUpsamplingNearestNode, Image, ImageDescriptor, NNGraph,
5    NNImageNode, RnnSingleGateDescriptor,
6};
7
8fn main() {
9    let device = MetalDevice::system_default().expect("no Metal device available");
10    let queue = device.new_command_queue().expect("command queue");
11
12    let input_node = NNImageNode::new().expect("input node");
13    input_node.set_format(feature_channel_format::FLOAT32);
14    let relu = CnnNeuronReluNode::new(&input_node, 0.0).expect("relu node");
15    let relu_result = relu.result_image().expect("relu result image");
16    relu_result.set_format(feature_channel_format::FLOAT32);
17    relu_result.set_synchronize_resource(true);
18    relu_result.use_default_allocator();
19    let pooling = CnnPoolingMaxNode::new(&relu_result, 2, 2).expect("pooling node");
20    assert!(
21        pooling.result_image().is_some(),
22        "pooling result image should exist"
23    );
24    let softmax = CnnSoftMaxNode::new(&relu_result).expect("softmax node");
25    assert!(
26        softmax.result_image().is_some(),
27        "softmax result image should exist"
28    );
29    let upsampling = CnnUpsamplingNearestNode::new(&relu_result, 2, 2).expect("upsampling node");
30    assert!(
31        upsampling.result_image().is_some(),
32        "upsampling result image should exist"
33    );
34
35    let graph = NNGraph::new(&device, &relu_result, true).expect("graph");
36    graph.set_format(feature_channel_format::FLOAT32);
37    graph.use_default_destination_image_allocator();
38
39    let descriptor = ImageDescriptor::new(2, 2, 1, feature_channel_format::FLOAT32);
40    let source = Image::new(&device, descriptor).expect("source image");
41    source
42        .write_f32(&[-1.0, 0.5, 2.0, -0.25])
43        .expect("write source image");
44
45    let command_buffer = queue.new_command_buffer().expect("command buffer");
46    let result = graph
47        .encode(&command_buffer, &[&source])
48        .expect("graph encode");
49    command_buffer.commit();
50    command_buffer.wait_until_completed();
51
52    let output = result.read_f32().expect("read graph output");
53    let expected = [0.0_f32, 0.5, 2.0, 0.0];
54    for (actual, expected_value) in output.iter().zip(expected) {
55        assert!(
56            (actual - expected_value).abs() < 1.0e-4,
57            "unexpected relu graph output: {output:?}"
58        );
59    }
60
61    let convolution = CnnConvolutionDescriptor::new(3, 3, 1, 4).expect("convolution descriptor");
62    convolution.set_stride_in_pixels_x(2);
63    convolution.set_stride_in_pixels_y(1);
64    convolution.set_groups(1);
65    convolution.set_dilation_rate_x(1);
66    convolution.set_dilation_rate_y(2);
67    assert_eq!(convolution.kernel_width(), 3);
68    assert_eq!(convolution.kernel_height(), 3);
69    assert_eq!(convolution.stride_in_pixels_x(), 2);
70    assert_eq!(convolution.stride_in_pixels_y(), 1);
71    assert_eq!(convolution.groups(), 1);
72    assert_eq!(convolution.dilation_rate_x(), 1);
73    assert_eq!(convolution.dilation_rate_y(), 2);
74
75    let rnn = RnnSingleGateDescriptor::new(3, 5).expect("rnn descriptor");
76    rnn.set_use_layer_input_unit_transform_mode(true);
77    rnn.set_use_float32_weights(true);
78    rnn.set_layer_sequence_direction(rnn_sequence_direction::BACKWARD);
79    assert_eq!(rnn.input_feature_channels(), 3);
80    assert_eq!(rnn.output_feature_channels(), 5);
81    assert!(rnn.use_layer_input_unit_transform_mode());
82    assert!(rnn.use_float32_weights());
83    assert_eq!(
84        rnn.layer_sequence_direction(),
85        rnn_sequence_direction::BACKWARD,
86        "expected backward RNN sequence direction"
87    );
88
89    println!(
90        "nn smoke passed: relu={output:?} source_images={}",
91        graph.source_image_count()
92    );
93}