pub struct CnnSoftMaxNode { /* private fields */ }
Implementations§
Source§impl CnnSoftMaxNode
impl CnnSoftMaxNode
Source
pub fn new(source: &NNImageNode) -> Option<Self>
pub fn new(source: &NNImageNode) -> Option<Self>
Examples found in repository?
examples/05_nn_graph_relu.rs (line 24)
8fn main() {
9 let device = MetalDevice::system_default().expect("no Metal device available");
10 let queue = device.new_command_queue().expect("command queue");
11
12 let input_node = NNImageNode::new().expect("input node");
13 input_node.set_format(feature_channel_format::FLOAT32);
14 let relu = CnnNeuronReluNode::new(&input_node, 0.0).expect("relu node");
15 let relu_result = relu.result_image().expect("relu result image");
16 relu_result.set_format(feature_channel_format::FLOAT32);
17 relu_result.set_synchronize_resource(true);
18 relu_result.use_default_allocator();
19 let pooling = CnnPoolingMaxNode::new(&relu_result, 2, 2).expect("pooling node");
20 assert!(
21 pooling.result_image().is_some(),
22 "pooling result image should exist"
23 );
24 let softmax = CnnSoftMaxNode::new(&relu_result).expect("softmax node");
25 assert!(
26 softmax.result_image().is_some(),
27 "softmax result image should exist"
28 );
29 let upsampling = CnnUpsamplingNearestNode::new(&relu_result, 2, 2).expect("upsampling node");
30 assert!(
31 upsampling.result_image().is_some(),
32 "upsampling result image should exist"
33 );
34
35 let graph = NNGraph::new(&device, &relu_result, true).expect("graph");
36 graph.set_format(feature_channel_format::FLOAT32);
37 graph.use_default_destination_image_allocator();
38
39 let descriptor = ImageDescriptor::new(2, 2, 1, feature_channel_format::FLOAT32);
40 let source = Image::new(&device, descriptor).expect("source image");
41 source
42 .write_f32(&[-1.0, 0.5, 2.0, -0.25])
43 .expect("write source image");
44
45 let command_buffer = queue.new_command_buffer().expect("command buffer");
46 let result = graph
47 .encode(&command_buffer, &[&source])
48 .expect("graph encode");
49 command_buffer.commit();
50 command_buffer.wait_until_completed();
51
52 let output = result.read_f32().expect("read graph output");
53 let expected = [0.0_f32, 0.5, 2.0, 0.0];
54 for (actual, expected_value) in output.iter().zip(expected) {
55 assert!(
56 (actual - expected_value).abs() < 1.0e-4,
57 "unexpected relu graph output: {output:?}"
58 );
59 }
60
61 let convolution = CnnConvolutionDescriptor::new(3, 3, 1, 4).expect("convolution descriptor");
62 convolution.set_stride_in_pixels_x(2);
63 convolution.set_stride_in_pixels_y(1);
64 convolution.set_groups(1);
65 convolution.set_dilation_rate_x(1);
66 convolution.set_dilation_rate_y(2);
67 assert_eq!(convolution.kernel_width(), 3);
68 assert_eq!(convolution.kernel_height(), 3);
69 assert_eq!(convolution.stride_in_pixels_x(), 2);
70 assert_eq!(convolution.stride_in_pixels_y(), 1);
71 assert_eq!(convolution.groups(), 1);
72 assert_eq!(convolution.dilation_rate_x(), 1);
73 assert_eq!(convolution.dilation_rate_y(), 2);
74
75 let rnn = RnnSingleGateDescriptor::new(3, 5).expect("rnn descriptor");
76 rnn.set_use_layer_input_unit_transform_mode(true);
77 rnn.set_use_float32_weights(true);
78 rnn.set_layer_sequence_direction(rnn_sequence_direction::BACKWARD);
79 assert_eq!(rnn.input_feature_channels(), 3);
80 assert_eq!(rnn.output_feature_channels(), 5);
81 assert!(rnn.use_layer_input_unit_transform_mode());
82 assert!(rnn.use_float32_weights());
83 assert_eq!(
84 rnn.layer_sequence_direction(),
85 rnn_sequence_direction::BACKWARD,
86 "expected backward RNN sequence direction"
87 );
88
89 println!(
90 "nn smoke passed: relu={output:?} source_images={}",
91 graph.source_image_count()
92 );
93}
Source§impl CnnSoftMaxNode
impl CnnSoftMaxNode
Source
pub fn result_image(&self) -> Option<NNImageNode>
pub fn result_image(&self) -> Option<NNImageNode>
Examples found in repository?
examples/05_nn_graph_relu.rs (line 26)
8fn main() {
9 let device = MetalDevice::system_default().expect("no Metal device available");
10 let queue = device.new_command_queue().expect("command queue");
11
12 let input_node = NNImageNode::new().expect("input node");
13 input_node.set_format(feature_channel_format::FLOAT32);
14 let relu = CnnNeuronReluNode::new(&input_node, 0.0).expect("relu node");
15 let relu_result = relu.result_image().expect("relu result image");
16 relu_result.set_format(feature_channel_format::FLOAT32);
17 relu_result.set_synchronize_resource(true);
18 relu_result.use_default_allocator();
19 let pooling = CnnPoolingMaxNode::new(&relu_result, 2, 2).expect("pooling node");
20 assert!(
21 pooling.result_image().is_some(),
22 "pooling result image should exist"
23 );
24 let softmax = CnnSoftMaxNode::new(&relu_result).expect("softmax node");
25 assert!(
26 softmax.result_image().is_some(),
27 "softmax result image should exist"
28 );
29 let upsampling = CnnUpsamplingNearestNode::new(&relu_result, 2, 2).expect("upsampling node");
30 assert!(
31 upsampling.result_image().is_some(),
32 "upsampling result image should exist"
33 );
34
35 let graph = NNGraph::new(&device, &relu_result, true).expect("graph");
36 graph.set_format(feature_channel_format::FLOAT32);
37 graph.use_default_destination_image_allocator();
38
39 let descriptor = ImageDescriptor::new(2, 2, 1, feature_channel_format::FLOAT32);
40 let source = Image::new(&device, descriptor).expect("source image");
41 source
42 .write_f32(&[-1.0, 0.5, 2.0, -0.25])
43 .expect("write source image");
44
45 let command_buffer = queue.new_command_buffer().expect("command buffer");
46 let result = graph
47 .encode(&command_buffer, &[&source])
48 .expect("graph encode");
49 command_buffer.commit();
50 command_buffer.wait_until_completed();
51
52 let output = result.read_f32().expect("read graph output");
53 let expected = [0.0_f32, 0.5, 2.0, 0.0];
54 for (actual, expected_value) in output.iter().zip(expected) {
55 assert!(
56 (actual - expected_value).abs() < 1.0e-4,
57 "unexpected relu graph output: {output:?}"
58 );
59 }
60
61 let convolution = CnnConvolutionDescriptor::new(3, 3, 1, 4).expect("convolution descriptor");
62 convolution.set_stride_in_pixels_x(2);
63 convolution.set_stride_in_pixels_y(1);
64 convolution.set_groups(1);
65 convolution.set_dilation_rate_x(1);
66 convolution.set_dilation_rate_y(2);
67 assert_eq!(convolution.kernel_width(), 3);
68 assert_eq!(convolution.kernel_height(), 3);
69 assert_eq!(convolution.stride_in_pixels_x(), 2);
70 assert_eq!(convolution.stride_in_pixels_y(), 1);
71 assert_eq!(convolution.groups(), 1);
72 assert_eq!(convolution.dilation_rate_x(), 1);
73 assert_eq!(convolution.dilation_rate_y(), 2);
74
75 let rnn = RnnSingleGateDescriptor::new(3, 5).expect("rnn descriptor");
76 rnn.set_use_layer_input_unit_transform_mode(true);
77 rnn.set_use_float32_weights(true);
78 rnn.set_layer_sequence_direction(rnn_sequence_direction::BACKWARD);
79 assert_eq!(rnn.input_feature_channels(), 3);
80 assert_eq!(rnn.output_feature_channels(), 5);
81 assert!(rnn.use_layer_input_unit_transform_mode());
82 assert!(rnn.use_float32_weights());
83 assert_eq!(
84 rnn.layer_sequence_direction(),
85 rnn_sequence_direction::BACKWARD,
86 "expected backward RNN sequence direction"
87 );
88
89 println!(
90 "nn smoke passed: relu={output:?} source_images={}",
91 graph.source_image_count()
92 );
93}
Trait Implementations§
Source§impl Drop for CnnSoftMaxNode
impl Drop for CnnSoftMaxNode
impl Send for CnnSoftMaxNode
impl Sync for CnnSoftMaxNode
Auto Trait Implementations§
impl Freeze for CnnSoftMaxNode
impl RefUnwindSafe for CnnSoftMaxNode
impl Unpin for CnnSoftMaxNode
impl UnsafeUnpin for CnnSoftMaxNode
impl UnwindSafe for CnnSoftMaxNode
Blanket Implementations§
Source§impl<T> BorrowMut<T> for T where
T: ?Sized,
impl<T> BorrowMut<T> for T where
T: ?Sized,
Source§fn borrow_mut(&mut self) -> &mut T
fn borrow_mut(&mut self) -> &mut T
Mutably borrows from an owned value. Read more