Skip to main content

apple_mps/
neural.rs

1use crate::ffi;
2use crate::image::Image;
3use apple_metal::{CommandBuffer, MetalDevice};
4use core::ffi::c_void;
5use core::ptr;
6
/// Named raw values of `MPSRNNSequenceDirection`.
///
/// NOTE(review): presumably these are the values accepted by
/// `RnnSingleGateDescriptor::set_layer_sequence_direction` — confirm against the FFI shim.
pub mod rnn_sequence_direction {
    /// Raw value `0`.
    pub const FORWARD: usize = 0;
    /// Raw value `1`.
    pub const BACKWARD: usize = 1;
}
12
/// Declares a public wrapper type `$name` around an opaque FFI object pointer.
///
/// The generated type:
/// - stores the raw `*mut c_void` in a private `ptr` field,
/// - exposes it through `as_ptr`,
/// - calls `ffi::mps_object_release` on it when dropped, unless the pointer is null.
///
/// NOTE(review): the type is unconditionally marked `Send` and `Sync`. That is only
/// sound if the wrapped object may be used and released from any thread — confirm.
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        impl $name {
            /// Returns the raw handle without giving up ownership of it.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                // Nothing was ever acquired for a null handle, so nothing to release.
                if self.ptr.is_null() {
                    return;
                }
                unsafe { ffi::mps_object_release(self.ptr) };
                // Clear the field so the released handle is not left dangling in `self`.
                self.ptr = ptr::null_mut();
            }
        }

        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}
    };
}
39
/// Adds a `result_image` accessor to the filter-node wrapper `$name`.
///
/// `$name` must have a `ptr` field holding the filter node's raw handle.
///
/// NOTE(review): the returned `NNImageNode` owns the handle it wraps and releases it on
/// drop, so `ffi::mps_nn_filter_node_result_image` is assumed to hand back an owned
/// (retained) reference — confirm against the FFI shim.
macro_rules! impl_filter_result_image {
    ($name:ident) => {
        impl $name {
            /// Fetches the image node produced by this filter node.
            ///
            /// Returns `None` when the FFI call yields a null handle.
            #[must_use]
            pub fn result_image(&self) -> Option<NNImageNode> {
                let image = unsafe { ffi::mps_nn_filter_node_result_image(self.ptr) };
                (!image.is_null()).then(|| NNImageNode { ptr: image })
            }
        }
    };
}
55
56opaque_handle!(NNImageNode);
57impl NNImageNode {
58    #[must_use]
59    pub fn new() -> Option<Self> {
60        let ptr = unsafe { ffi::mps_nn_image_node_new() };
61        if ptr.is_null() {
62            None
63        } else {
64            Some(Self { ptr })
65        }
66    }
67
68    #[must_use]
69    pub fn exported() -> Option<Self> {
70        let ptr = unsafe { ffi::mps_nn_image_node_exported() };
71        if ptr.is_null() {
72            None
73        } else {
74            Some(Self { ptr })
75        }
76    }
77
78    #[must_use]
79    pub fn format(&self) -> usize {
80        unsafe { ffi::mps_nn_image_node_format(self.ptr) }
81    }
82
83    pub fn set_format(&self, format: usize) {
84        unsafe { ffi::mps_nn_image_node_set_format(self.ptr, format) };
85    }
86
87    #[must_use]
88    pub fn export_from_graph(&self) -> bool {
89        unsafe { ffi::mps_nn_image_node_export_from_graph(self.ptr) }
90    }
91
92    pub fn set_export_from_graph(&self, export: bool) {
93        unsafe { ffi::mps_nn_image_node_set_export_from_graph(self.ptr, export) };
94    }
95
96    #[must_use]
97    pub fn synchronize_resource(&self) -> bool {
98        unsafe { ffi::mps_nn_image_node_synchronize_resource(self.ptr) }
99    }
100
101    pub fn set_synchronize_resource(&self, synchronize: bool) {
102        unsafe { ffi::mps_nn_image_node_set_synchronize_resource(self.ptr, synchronize) };
103    }
104
105    pub fn use_default_allocator(&self) {
106        unsafe { ffi::mps_nn_image_node_use_default_allocator(self.ptr) };
107    }
108}
109
110opaque_handle!(CnnNeuronReluNode);
111impl CnnNeuronReluNode {
112    #[must_use]
113    pub fn new(source: &NNImageNode, a: f32) -> Option<Self> {
114        let ptr = unsafe { ffi::mps_cnn_neuron_relu_node_new(source.as_ptr(), a) };
115        if ptr.is_null() {
116            None
117        } else {
118            Some(Self { ptr })
119        }
120    }
121}
122impl_filter_result_image!(CnnNeuronReluNode);
123
124opaque_handle!(CnnPoolingMaxNode);
125impl CnnPoolingMaxNode {
126    #[must_use]
127    pub fn new(source: &NNImageNode, filter_size: usize, stride: usize) -> Option<Self> {
128        let ptr =
129            unsafe { ffi::mps_cnn_pooling_max_node_new(source.as_ptr(), filter_size, stride) };
130        if ptr.is_null() {
131            None
132        } else {
133            Some(Self { ptr })
134        }
135    }
136}
137impl_filter_result_image!(CnnPoolingMaxNode);
138
139opaque_handle!(CnnSoftMaxNode);
140impl CnnSoftMaxNode {
141    #[must_use]
142    pub fn new(source: &NNImageNode) -> Option<Self> {
143        let ptr = unsafe { ffi::mps_cnn_softmax_node_new(source.as_ptr()) };
144        if ptr.is_null() {
145            None
146        } else {
147            Some(Self { ptr })
148        }
149    }
150}
151impl_filter_result_image!(CnnSoftMaxNode);
152
153opaque_handle!(CnnUpsamplingNearestNode);
154impl CnnUpsamplingNearestNode {
155    #[must_use]
156    pub fn new(source: &NNImageNode, scale_x: usize, scale_y: usize) -> Option<Self> {
157        let ptr =
158            unsafe { ffi::mps_cnn_upsampling_nearest_node_new(source.as_ptr(), scale_x, scale_y) };
159        if ptr.is_null() {
160            None
161        } else {
162            Some(Self { ptr })
163        }
164    }
165}
166impl_filter_result_image!(CnnUpsamplingNearestNode);
167
168opaque_handle!(NNGraph);
169impl NNGraph {
170    #[must_use]
171    pub fn new(
172        device: &MetalDevice,
173        result_image: &NNImageNode,
174        result_image_is_needed: bool,
175    ) -> Option<Self> {
176        let ptr = unsafe {
177            ffi::mps_nn_graph_new(
178                device.as_ptr(),
179                result_image.as_ptr(),
180                result_image_is_needed,
181            )
182        };
183        if ptr.is_null() {
184            None
185        } else {
186            Some(Self { ptr })
187        }
188    }
189
190    #[must_use]
191    pub fn source_image_count(&self) -> usize {
192        unsafe { ffi::mps_nn_graph_source_image_count(self.ptr) }
193    }
194
195    #[must_use]
196    pub fn format(&self) -> usize {
197        unsafe { ffi::mps_nn_graph_format(self.ptr) }
198    }
199
200    pub fn set_format(&self, format: usize) {
201        unsafe { ffi::mps_nn_graph_set_format(self.ptr, format) };
202    }
203
204    pub fn set_output_state_is_temporary(&self, temporary: bool) {
205        unsafe { ffi::mps_nn_graph_set_output_state_is_temporary(self.ptr, temporary) };
206    }
207
208    pub fn use_default_destination_image_allocator(&self) {
209        unsafe { ffi::mps_nn_graph_use_default_destination_image_allocator(self.ptr) };
210    }
211
212    pub fn reload_from_data_sources(&self) {
213        unsafe { ffi::mps_nn_graph_reload_from_data_sources(self.ptr) };
214    }
215
216    #[must_use]
217    pub fn encode(
218        &self,
219        command_buffer: &CommandBuffer,
220        source_images: &[&Image],
221    ) -> Option<Image> {
222        let handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
223        let source_handles = if handles.is_empty() {
224            ptr::null()
225        } else {
226            handles.as_ptr()
227        };
228        let ptr = unsafe {
229            ffi::mps_nn_graph_encode(
230                self.ptr,
231                command_buffer.as_ptr(),
232                source_images.len(),
233                source_handles,
234            )
235        };
236        if ptr.is_null() {
237            None
238        } else {
239            Some(unsafe { Image::from_raw(ptr) })
240        }
241    }
242}
243
244opaque_handle!(CnnConvolutionDescriptor);
245impl CnnConvolutionDescriptor {
246    #[must_use]
247    pub fn new(
248        kernel_width: usize,
249        kernel_height: usize,
250        input_feature_channels: usize,
251        output_feature_channels: usize,
252    ) -> Option<Self> {
253        let ptr = unsafe {
254            ffi::mps_cnn_convolution_descriptor_new(
255                kernel_width,
256                kernel_height,
257                input_feature_channels,
258                output_feature_channels,
259            )
260        };
261        if ptr.is_null() {
262            None
263        } else {
264            Some(Self { ptr })
265        }
266    }
267
268    #[must_use]
269    pub fn kernel_width(&self) -> usize {
270        unsafe { ffi::mps_cnn_convolution_descriptor_kernel_width(self.ptr) }
271    }
272
273    #[must_use]
274    pub fn kernel_height(&self) -> usize {
275        unsafe { ffi::mps_cnn_convolution_descriptor_kernel_height(self.ptr) }
276    }
277
278    #[must_use]
279    pub fn stride_in_pixels_x(&self) -> usize {
280        unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_x(self.ptr) }
281    }
282
283    pub fn set_stride_in_pixels_x(&self, value: usize) {
284        unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_x(self.ptr, value) };
285    }
286
287    #[must_use]
288    pub fn stride_in_pixels_y(&self) -> usize {
289        unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_y(self.ptr) }
290    }
291
292    pub fn set_stride_in_pixels_y(&self, value: usize) {
293        unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_y(self.ptr, value) };
294    }
295
296    #[must_use]
297    pub fn groups(&self) -> usize {
298        unsafe { ffi::mps_cnn_convolution_descriptor_groups(self.ptr) }
299    }
300
301    pub fn set_groups(&self, value: usize) {
302        unsafe { ffi::mps_cnn_convolution_descriptor_set_groups(self.ptr, value) };
303    }
304
305    #[must_use]
306    pub fn dilation_rate_x(&self) -> usize {
307        unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_x(self.ptr) }
308    }
309
310    pub fn set_dilation_rate_x(&self, value: usize) {
311        unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_x(self.ptr, value) };
312    }
313
314    #[must_use]
315    pub fn dilation_rate_y(&self) -> usize {
316        unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_y(self.ptr) }
317    }
318
319    pub fn set_dilation_rate_y(&self, value: usize) {
320        unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_y(self.ptr, value) };
321    }
322}
323
324opaque_handle!(RnnSingleGateDescriptor);
325impl RnnSingleGateDescriptor {
326    #[must_use]
327    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
328        let ptr = unsafe {
329            ffi::mps_rnn_single_gate_descriptor_new(input_feature_channels, output_feature_channels)
330        };
331        if ptr.is_null() {
332            None
333        } else {
334            Some(Self { ptr })
335        }
336    }
337
338    #[must_use]
339    pub fn input_feature_channels(&self) -> usize {
340        unsafe { ffi::mps_rnn_single_gate_descriptor_input_feature_channels(self.ptr) }
341    }
342
343    pub fn set_input_feature_channels(&self, value: usize) {
344        unsafe { ffi::mps_rnn_single_gate_descriptor_set_input_feature_channels(self.ptr, value) };
345    }
346
347    #[must_use]
348    pub fn output_feature_channels(&self) -> usize {
349        unsafe { ffi::mps_rnn_single_gate_descriptor_output_feature_channels(self.ptr) }
350    }
351
352    pub fn set_output_feature_channels(&self, value: usize) {
353        unsafe { ffi::mps_rnn_single_gate_descriptor_set_output_feature_channels(self.ptr, value) };
354    }
355
356    #[must_use]
357    pub fn use_layer_input_unit_transform_mode(&self) -> bool {
358        unsafe { ffi::mps_rnn_single_gate_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
359    }
360
361    pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
362        unsafe {
363            ffi::mps_rnn_single_gate_descriptor_set_use_layer_input_unit_transform_mode(
364                self.ptr, value,
365            );
366        };
367    }
368
369    #[must_use]
370    pub fn use_float32_weights(&self) -> bool {
371        unsafe { ffi::mps_rnn_single_gate_descriptor_use_float32_weights(self.ptr) }
372    }
373
374    pub fn set_use_float32_weights(&self, value: bool) {
375        unsafe { ffi::mps_rnn_single_gate_descriptor_set_use_float32_weights(self.ptr, value) };
376    }
377
378    #[must_use]
379    pub fn layer_sequence_direction(&self) -> usize {
380        unsafe { ffi::mps_rnn_single_gate_descriptor_layer_sequence_direction(self.ptr) }
381    }
382
383    pub fn set_layer_sequence_direction(&self, value: usize) {
384        unsafe {
385            ffi::mps_rnn_single_gate_descriptor_set_layer_sequence_direction(self.ptr, value);
386        };
387    }
388}