1use crate::ffi;
2use crate::image::Image;
3use apple_metal::{CommandBuffer, MetalDevice};
4use core::ffi::c_void;
5use core::ptr;
6
/// Raw `usize` values for an RNN layer's sequence direction.
///
/// These are intended for `RnnSingleGateDescriptor::set_layer_sequence_direction`, which
/// forwards the value to the FFI layer unchanged.
pub mod rnn_sequence_direction {
    /// Forward direction (raw value `0`).
    ///
    /// NOTE(review): presumably mirrors MPS `MPSRNNSequenceDirectionForward` — confirm.
    pub const FORWARD: usize = 0;

    /// Backward direction (raw value `1`).
    ///
    /// NOTE(review): presumably mirrors MPS `MPSRNNSequenceDirectionBackward` — confirm.
    pub const BACKWARD: usize = 1;
}
/// Declares a public owning wrapper `$name` around an opaque object pointer.
///
/// The generated type:
/// - stores a raw `*mut c_void` in a private `ptr` field;
/// - exposes that pointer, without giving up ownership, through `as_ptr`;
/// - passes a non-null pointer to `ffi::mps_object_release` when dropped, and leaves the
///   field null afterwards. A null pointer is never released.
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review) — these impls allow the handle to be moved to and shared
        // between threads. The wrapper adds no locking, and handle types in this file expose
        // `&self` setters, so soundness depends on the underlying object tolerating that.
        // Confirm against the framework's threading guarantees.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Move the pointer out first so the handle is always null after drop.
                let owned = ::core::mem::replace(&mut self.ptr, ptr::null_mut());
                if owned.is_null() {
                    return;
                }
                // SAFETY: `owned` is non-null and was held only by this handle, so it is
                // released exactly once here.
                unsafe { ffi::mps_object_release(owned) };
            }
        }

        impl $name {
            /// Returns the wrapped raw pointer; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
39
/// Adds a `result_image` accessor to the handle type `$name`.
///
/// The generated method queries `ffi::mps_nn_filter_node_result_image` with the handle's
/// pointer. It wraps any non-null result in an owning `NNImageNode`.
macro_rules! impl_filter_result_image {
    ($name:ident) => {
        impl $name {
            /// Returns the image node produced by this filter node. Returns `None` if the
            /// FFI call yields a null pointer.
            ///
            /// NOTE(review): the returned `NNImageNode` releases its pointer on drop, so this
            /// assumes the FFI shim returns a retained (+1) reference. Confirm in the shim.
            #[must_use]
            pub fn result_image(&self) -> Option<NNImageNode> {
                // SAFETY: `self.ptr` is the pointer held by this handle.
                let image = unsafe { ffi::mps_nn_filter_node_result_image(self.ptr) };
                if image.is_null() {
                    return None;
                }
                Some(NNImageNode { ptr: image })
            }
        }
    };
}
55
56opaque_handle!(NNImageNode);
57impl NNImageNode {
58 #[must_use]
59 pub fn new() -> Option<Self> {
60 let ptr = unsafe { ffi::mps_nn_image_node_new() };
61 if ptr.is_null() {
62 None
63 } else {
64 Some(Self { ptr })
65 }
66 }
67
68 #[must_use]
69 pub fn exported() -> Option<Self> {
70 let ptr = unsafe { ffi::mps_nn_image_node_exported() };
71 if ptr.is_null() {
72 None
73 } else {
74 Some(Self { ptr })
75 }
76 }
77
78 #[must_use]
79 pub fn format(&self) -> usize {
80 unsafe { ffi::mps_nn_image_node_format(self.ptr) }
81 }
82
83 pub fn set_format(&self, format: usize) {
84 unsafe { ffi::mps_nn_image_node_set_format(self.ptr, format) };
85 }
86
87 #[must_use]
88 pub fn export_from_graph(&self) -> bool {
89 unsafe { ffi::mps_nn_image_node_export_from_graph(self.ptr) }
90 }
91
92 pub fn set_export_from_graph(&self, export: bool) {
93 unsafe { ffi::mps_nn_image_node_set_export_from_graph(self.ptr, export) };
94 }
95
96 #[must_use]
97 pub fn synchronize_resource(&self) -> bool {
98 unsafe { ffi::mps_nn_image_node_synchronize_resource(self.ptr) }
99 }
100
101 pub fn set_synchronize_resource(&self, synchronize: bool) {
102 unsafe { ffi::mps_nn_image_node_set_synchronize_resource(self.ptr, synchronize) };
103 }
104
105 pub fn use_default_allocator(&self) {
106 unsafe { ffi::mps_nn_image_node_use_default_allocator(self.ptr) };
107 }
108}
109
110opaque_handle!(CnnNeuronReluNode);
111impl CnnNeuronReluNode {
112 #[must_use]
113 pub fn new(source: &NNImageNode, a: f32) -> Option<Self> {
114 let ptr = unsafe { ffi::mps_cnn_neuron_relu_node_new(source.as_ptr(), a) };
115 if ptr.is_null() {
116 None
117 } else {
118 Some(Self { ptr })
119 }
120 }
121}
122impl_filter_result_image!(CnnNeuronReluNode);
123
124opaque_handle!(CnnPoolingMaxNode);
125impl CnnPoolingMaxNode {
126 #[must_use]
127 pub fn new(source: &NNImageNode, filter_size: usize, stride: usize) -> Option<Self> {
128 let ptr =
129 unsafe { ffi::mps_cnn_pooling_max_node_new(source.as_ptr(), filter_size, stride) };
130 if ptr.is_null() {
131 None
132 } else {
133 Some(Self { ptr })
134 }
135 }
136}
137impl_filter_result_image!(CnnPoolingMaxNode);
138
139opaque_handle!(CnnSoftMaxNode);
140impl CnnSoftMaxNode {
141 #[must_use]
142 pub fn new(source: &NNImageNode) -> Option<Self> {
143 let ptr = unsafe { ffi::mps_cnn_softmax_node_new(source.as_ptr()) };
144 if ptr.is_null() {
145 None
146 } else {
147 Some(Self { ptr })
148 }
149 }
150}
151impl_filter_result_image!(CnnSoftMaxNode);
152
153opaque_handle!(CnnUpsamplingNearestNode);
154impl CnnUpsamplingNearestNode {
155 #[must_use]
156 pub fn new(source: &NNImageNode, scale_x: usize, scale_y: usize) -> Option<Self> {
157 let ptr =
158 unsafe { ffi::mps_cnn_upsampling_nearest_node_new(source.as_ptr(), scale_x, scale_y) };
159 if ptr.is_null() {
160 None
161 } else {
162 Some(Self { ptr })
163 }
164 }
165}
166impl_filter_result_image!(CnnUpsamplingNearestNode);
167
168opaque_handle!(NNGraph);
169impl NNGraph {
170 #[must_use]
171 pub fn new(
172 device: &MetalDevice,
173 result_image: &NNImageNode,
174 result_image_is_needed: bool,
175 ) -> Option<Self> {
176 let ptr = unsafe {
177 ffi::mps_nn_graph_new(
178 device.as_ptr(),
179 result_image.as_ptr(),
180 result_image_is_needed,
181 )
182 };
183 if ptr.is_null() {
184 None
185 } else {
186 Some(Self { ptr })
187 }
188 }
189
190 #[must_use]
191 pub fn source_image_count(&self) -> usize {
192 unsafe { ffi::mps_nn_graph_source_image_count(self.ptr) }
193 }
194
195 #[must_use]
196 pub fn format(&self) -> usize {
197 unsafe { ffi::mps_nn_graph_format(self.ptr) }
198 }
199
200 pub fn set_format(&self, format: usize) {
201 unsafe { ffi::mps_nn_graph_set_format(self.ptr, format) };
202 }
203
204 pub fn set_output_state_is_temporary(&self, temporary: bool) {
205 unsafe { ffi::mps_nn_graph_set_output_state_is_temporary(self.ptr, temporary) };
206 }
207
208 pub fn use_default_destination_image_allocator(&self) {
209 unsafe { ffi::mps_nn_graph_use_default_destination_image_allocator(self.ptr) };
210 }
211
212 pub fn reload_from_data_sources(&self) {
213 unsafe { ffi::mps_nn_graph_reload_from_data_sources(self.ptr) };
214 }
215
216 #[must_use]
217 pub fn encode(
218 &self,
219 command_buffer: &CommandBuffer,
220 source_images: &[&Image],
221 ) -> Option<Image> {
222 let handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
223 let source_handles = if handles.is_empty() {
224 ptr::null()
225 } else {
226 handles.as_ptr()
227 };
228 let ptr = unsafe {
229 ffi::mps_nn_graph_encode(
230 self.ptr,
231 command_buffer.as_ptr(),
232 source_images.len(),
233 source_handles,
234 )
235 };
236 if ptr.is_null() {
237 None
238 } else {
239 Some(unsafe { Image::from_raw(ptr) })
240 }
241 }
242}
243
244opaque_handle!(CnnConvolutionDescriptor);
245impl CnnConvolutionDescriptor {
246 #[must_use]
247 pub fn new(
248 kernel_width: usize,
249 kernel_height: usize,
250 input_feature_channels: usize,
251 output_feature_channels: usize,
252 ) -> Option<Self> {
253 let ptr = unsafe {
254 ffi::mps_cnn_convolution_descriptor_new(
255 kernel_width,
256 kernel_height,
257 input_feature_channels,
258 output_feature_channels,
259 )
260 };
261 if ptr.is_null() {
262 None
263 } else {
264 Some(Self { ptr })
265 }
266 }
267
268 #[must_use]
269 pub fn kernel_width(&self) -> usize {
270 unsafe { ffi::mps_cnn_convolution_descriptor_kernel_width(self.ptr) }
271 }
272
273 #[must_use]
274 pub fn kernel_height(&self) -> usize {
275 unsafe { ffi::mps_cnn_convolution_descriptor_kernel_height(self.ptr) }
276 }
277
278 #[must_use]
279 pub fn stride_in_pixels_x(&self) -> usize {
280 unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_x(self.ptr) }
281 }
282
283 pub fn set_stride_in_pixels_x(&self, value: usize) {
284 unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_x(self.ptr, value) };
285 }
286
287 #[must_use]
288 pub fn stride_in_pixels_y(&self) -> usize {
289 unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_y(self.ptr) }
290 }
291
292 pub fn set_stride_in_pixels_y(&self, value: usize) {
293 unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_y(self.ptr, value) };
294 }
295
296 #[must_use]
297 pub fn groups(&self) -> usize {
298 unsafe { ffi::mps_cnn_convolution_descriptor_groups(self.ptr) }
299 }
300
301 pub fn set_groups(&self, value: usize) {
302 unsafe { ffi::mps_cnn_convolution_descriptor_set_groups(self.ptr, value) };
303 }
304
305 #[must_use]
306 pub fn dilation_rate_x(&self) -> usize {
307 unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_x(self.ptr) }
308 }
309
310 pub fn set_dilation_rate_x(&self, value: usize) {
311 unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_x(self.ptr, value) };
312 }
313
314 #[must_use]
315 pub fn dilation_rate_y(&self) -> usize {
316 unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_y(self.ptr) }
317 }
318
319 pub fn set_dilation_rate_y(&self, value: usize) {
320 unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_y(self.ptr, value) };
321 }
322}
323
324opaque_handle!(RnnSingleGateDescriptor);
325impl RnnSingleGateDescriptor {
326 #[must_use]
327 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
328 let ptr = unsafe {
329 ffi::mps_rnn_single_gate_descriptor_new(input_feature_channels, output_feature_channels)
330 };
331 if ptr.is_null() {
332 None
333 } else {
334 Some(Self { ptr })
335 }
336 }
337
338 #[must_use]
339 pub fn input_feature_channels(&self) -> usize {
340 unsafe { ffi::mps_rnn_single_gate_descriptor_input_feature_channels(self.ptr) }
341 }
342
343 pub fn set_input_feature_channels(&self, value: usize) {
344 unsafe { ffi::mps_rnn_single_gate_descriptor_set_input_feature_channels(self.ptr, value) };
345 }
346
347 #[must_use]
348 pub fn output_feature_channels(&self) -> usize {
349 unsafe { ffi::mps_rnn_single_gate_descriptor_output_feature_channels(self.ptr) }
350 }
351
352 pub fn set_output_feature_channels(&self, value: usize) {
353 unsafe { ffi::mps_rnn_single_gate_descriptor_set_output_feature_channels(self.ptr, value) };
354 }
355
356 #[must_use]
357 pub fn use_layer_input_unit_transform_mode(&self) -> bool {
358 unsafe { ffi::mps_rnn_single_gate_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
359 }
360
361 pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
362 unsafe {
363 ffi::mps_rnn_single_gate_descriptor_set_use_layer_input_unit_transform_mode(
364 self.ptr, value,
365 );
366 };
367 }
368
369 #[must_use]
370 pub fn use_float32_weights(&self) -> bool {
371 unsafe { ffi::mps_rnn_single_gate_descriptor_use_float32_weights(self.ptr) }
372 }
373
374 pub fn set_use_float32_weights(&self, value: bool) {
375 unsafe { ffi::mps_rnn_single_gate_descriptor_set_use_float32_weights(self.ptr, value) };
376 }
377
378 #[must_use]
379 pub fn layer_sequence_direction(&self) -> usize {
380 unsafe { ffi::mps_rnn_single_gate_descriptor_layer_sequence_direction(self.ptr) }
381 }
382
383 pub fn set_layer_sequence_direction(&self, value: usize) {
384 unsafe {
385 ffi::mps_rnn_single_gate_descriptor_set_layer_sequence_direction(self.ptr, value);
386 };
387 }
388}