Skip to main content

apple_mps/
neural.rs

1use crate::ffi;
2use crate::image::Image;
3use crate::matrix::{Matrix, Vector};
4use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
5use core::ffi::c_void;
6use core::ptr;
7
/// Raw values of the `MPSRNNSequenceDirection` enumeration.
pub mod rnn_sequence_direction {
    /// Raw value of `MPSRNNSequenceDirectionForward`.
    pub const FORWARD: usize = 0;
    /// Raw value of `MPSRNNSequenceDirectionBackward`.
    pub const BACKWARD: usize = 1;
}
15
/// Raw values of the `MPSRNNBidirectionalCombineMode` enumeration.
pub mod rnn_bidirectional_combine_mode {
    /// Raw value of `MPSRNNBidirectionalCombineModeNone`.
    pub const NONE: usize = 0;
    /// Raw value of `MPSRNNBidirectionalCombineModeAdd`.
    pub const ADD: usize = 1;
    /// Raw value of `MPSRNNBidirectionalCombineModeConcatenate`.
    pub const CONCATENATE: usize = 2;
}
25
/// Raw values of the `MPSCNNConvolutionFlags` option set.
pub mod cnn_convolution_flags {
    /// Raw value of `MPSCNNConvolutionFlagsNone` (no flag bits set).
    pub const NONE: usize = 0;
}
31
/// Raw values of the `MPSCNNConvolutionWeightsLayout` enumeration.
///
/// Unlike the other constant modules in this file these are typed `u32`; presumably this
/// mirrors a 32-bit underlying enum type in the MPS headers — confirm before changing.
pub mod cnn_convolution_weights_layout {
    /// Raw value of `MPSCNNConvolutionWeightsLayoutOHWI`.
    pub const OHWI: u32 = 0;
}
37
/// Raw values of the `MPSNNConvolutionAccumulatorPrecisionOption` enumeration.
pub mod cnn_accumulator_precision_option {
    /// Raw value of `MPSNNConvolutionAccumulatorPrecisionOptionHalf`.
    pub const HALF: usize = 0;
    /// Raw value of `MPSNNConvolutionAccumulatorPrecisionOptionFloat`.
    pub const FLOAT: usize = 1;
}
45
/// Raw values of the `MPSNNRegularizationType` enumeration.
pub mod nn_regularization_type {
    /// Raw value of `MPSNNRegularizationTypeNone`.
    pub const NONE: usize = 0;
    /// Raw value of `MPSNNRegularizationTypeL1`.
    pub const L1: usize = 1;
    /// Raw value of `MPSNNRegularizationTypeL2`.
    pub const L2: usize = 2;
}
55
56#[doc(hidden)]
57pub use crate::generated::neural::*;
58
// Declares an owning wrapper type `$name` around one retained MPS Objective-C object.
//
// The generated type stores the raw pointer, releases it once in `Drop`, is marked
// `Send`/`Sync`, and exposes a borrowing `as_ptr` accessor.
macro_rules! opaque_handle {
    ($name:ident, $doc:expr) => {
        #[doc = $doc]
        pub struct $name {
            // Objective-C object pointer carrying one retain count owned by this wrapper;
            // released (and nulled) in `Drop`.
            ptr: *mut c_void,
        }

        // SAFETY: the wrapper only stores an opaque pointer to an MPS Objective-C object,
        // which is assumed to be safe to hand to another thread.
        unsafe impl Send for $name {}
        // SAFETY: the wrapped object is assumed safe to share across threads.
        // NOTE(review): several generated types expose setters taking `&self`, so `Sync`
        // permits concurrent mutation of one object — confirm MPS tolerates that.
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Detach the pointer first so the handle ends up null in every case.
                let owned = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if owned.is_null() {
                    return;
                }
                // SAFETY: `owned` held the +1 retain this wrapper was responsible for and has
                // been detached from `self`, so it is released exactly once.
                unsafe { ffi::mps_object_release(owned) };
            }
        }

        impl $name {
            /// Borrows the raw Objective-C pointer without transferring ownership.
            ///
            /// The pointer remains owned by `self` (it is released when `self` drops), so the
            /// caller must not release it and must not use it after `self` is gone.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
90
// Adds a `result_image` accessor to a filter-node wrapper `$name`.
macro_rules! impl_filter_result_image {
    ($name:ident) => {
        impl $name {
            /// Returns the image node produced by this filter node, or `None` when
            /// `mps_nn_filter_node_result_image` yields null.
            ///
            /// NOTE(review): the returned pointer is wrapped in an owning `NNImageNode`,
            /// which releases it on drop, so this assumes the FFI hands back a +1 retained
            /// object — confirm against the shim.
            #[must_use]
            pub fn result_image(&self) -> Option<NNImageNode> {
                // SAFETY: `self.ptr` is the filter-node object owned by this wrapper.
                let node = unsafe { ffi::mps_nn_filter_node_result_image(self.ptr) };
                if node.is_null() {
                    return None;
                }
                Some(NNImageNode { ptr: node })
            }
        }
    };
}
107
108fn retained_handle(ptr: *mut c_void) -> Option<*mut c_void> {
109    let retained = unsafe { ffi::mps_object_retain(ptr) };
110    if retained.is_null() {
111        None
112    } else {
113        Some(retained)
114    }
115}
116
// Adds the accessors shared by every RNN descriptor wrapper `$name`; each one forwards
// `self.ptr` to the matching `ffi::mps_rnn_descriptor_*` function.
macro_rules! impl_rnn_descriptor_common {
    ($name:ident) => {
        impl $name {
            /// Reads the descriptor's input feature channel count.
            #[must_use]
            pub fn input_feature_channels(&self) -> usize {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_input_feature_channels(self.ptr) }
            }

            /// Writes the descriptor's input feature channel count.
            pub fn set_input_feature_channels(&self, value: usize) {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_set_input_feature_channels(self.ptr, value) };
            }

            /// Reads the descriptor's output feature channel count.
            #[must_use]
            pub fn output_feature_channels(&self) -> usize {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_output_feature_channels(self.ptr) }
            }

            /// Writes the descriptor's output feature channel count.
            pub fn set_output_feature_channels(&self, value: usize) {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_set_output_feature_channels(self.ptr, value) };
            }

            /// Reads the `useLayerInputUnitTransformMode` flag.
            #[must_use]
            pub fn use_layer_input_unit_transform_mode(&self) -> bool {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
            }

            /// Writes the `useLayerInputUnitTransformMode` flag.
            pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe {
                    ffi::mps_rnn_descriptor_set_use_layer_input_unit_transform_mode(self.ptr, value)
                };
            }

            /// Reads whether the descriptor requests 32-bit float weights.
            #[must_use]
            pub fn use_float32_weights(&self) -> bool {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_use_float32_weights(self.ptr) }
            }

            /// Sets whether the descriptor requests 32-bit float weights.
            pub fn set_use_float32_weights(&self, value: bool) {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_set_use_float32_weights(self.ptr, value) };
            }

            /// Reads the raw layer sequence direction (see `rnn_sequence_direction`).
            #[must_use]
            pub fn layer_sequence_direction(&self) -> usize {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_layer_sequence_direction(self.ptr) }
            }

            /// Writes the raw layer sequence direction (see `rnn_sequence_direction`).
            pub fn set_layer_sequence_direction(&self, value: usize) {
                // SAFETY: `self.ptr` is the RNN descriptor owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_set_layer_sequence_direction(self.ptr, value) };
            }
        }
    };
}
179
// Adds the accessors shared by every optimizer wrapper `$name`; each one forwards
// `self.ptr` to the matching `ffi::mps_nn_optimizer_*` function.
macro_rules! impl_optimizer_common {
    ($name:ident) => {
        impl $name {
            /// Reads the optimizer's learning rate.
            #[must_use]
            pub fn learning_rate(&self) -> f32 {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_learning_rate(self.ptr) }
            }

            /// Replaces the optimizer's learning rate.
            pub fn set_learning_rate(&self, value: f32) {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_set_learning_rate(self.ptr, value) };
            }

            /// Reads the gradient rescale factor.
            #[must_use]
            pub fn gradient_rescale(&self) -> f32 {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_gradient_rescale(self.ptr) }
            }

            /// Reads whether gradient clipping is applied.
            #[must_use]
            pub fn apply_gradient_clipping(&self) -> bool {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_apply_gradient_clipping(self.ptr) }
            }

            /// Enables or disables gradient clipping.
            pub fn set_apply_gradient_clipping(&self, value: bool) {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_set_apply_gradient_clipping(self.ptr, value) };
            }

            /// Reads the upper gradient clipping bound.
            #[must_use]
            pub fn gradient_clip_max(&self) -> f32 {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_gradient_clip_max(self.ptr) }
            }

            /// Reads the lower gradient clipping bound.
            #[must_use]
            pub fn gradient_clip_min(&self) -> f32 {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_gradient_clip_min(self.ptr) }
            }

            /// Reads the regularization scale.
            #[must_use]
            pub fn regularization_scale(&self) -> f32 {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_regularization_scale(self.ptr) }
            }

            /// Reads the raw regularization type (see `nn_regularization_type`).
            #[must_use]
            pub fn regularization_type(&self) -> usize {
                // SAFETY: `self.ptr` is the optimizer object owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_regularization_type(self.ptr) }
            }
        }
    };
}
237
238opaque_handle!(NNImageNode, "Wraps `MPSNNImageNode`.");
239impl NNImageNode {
240    /// Wraps a constructor on `MPSNNImageNode`.
241    #[must_use]
242    pub fn new() -> Option<Self> {
243        let ptr = unsafe { ffi::mps_nn_image_node_new() };
244        if ptr.is_null() {
245            None
246        } else {
247            Some(Self { ptr })
248        }
249    }
250
251    /// Wraps a constructor on `MPSNNImageNode`.
252    #[must_use]
253    pub fn exported() -> Option<Self> {
254        let ptr = unsafe { ffi::mps_nn_image_node_exported() };
255        if ptr.is_null() {
256            None
257        } else {
258            Some(Self { ptr })
259        }
260    }
261
262    /// Wraps the corresponding `MPSNNImageNode` method.
263    #[must_use]
264    pub fn format(&self) -> usize {
265        unsafe { ffi::mps_nn_image_node_format(self.ptr) }
266    }
267
268    /// Wraps the corresponding `MPSNNImageNode` setter.
269    pub fn set_format(&self, format: usize) {
270        unsafe { ffi::mps_nn_image_node_set_format(self.ptr, format) };
271    }
272
273    /// Wraps the corresponding `MPSNNImageNode` method.
274    #[must_use]
275    pub fn export_from_graph(&self) -> bool {
276        unsafe { ffi::mps_nn_image_node_export_from_graph(self.ptr) }
277    }
278
279    /// Wraps the corresponding `MPSNNImageNode` setter.
280    pub fn set_export_from_graph(&self, export: bool) {
281        unsafe { ffi::mps_nn_image_node_set_export_from_graph(self.ptr, export) };
282    }
283
284    /// Wraps the corresponding `MPSNNImageNode` method.
285    #[must_use]
286    pub fn synchronize_resource(&self) -> bool {
287        unsafe { ffi::mps_nn_image_node_synchronize_resource(self.ptr) }
288    }
289
290    /// Wraps the corresponding `MPSNNImageNode` setter.
291    pub fn set_synchronize_resource(&self, synchronize: bool) {
292        unsafe { ffi::mps_nn_image_node_set_synchronize_resource(self.ptr, synchronize) };
293    }
294
295    /// Wraps the corresponding `MPSNNImageNode` method.
296    pub fn use_default_allocator(&self) {
297        unsafe { ffi::mps_nn_image_node_use_default_allocator(self.ptr) };
298    }
299}
300
301opaque_handle!(CnnNeuronReluNode, "Wraps `MPSCNNNeuronReLUNode`.");
302impl CnnNeuronReluNode {
303    /// Wraps a constructor on `MPSCNNNeuronReLUNode`.
304    #[must_use]
305    pub fn new(source: &NNImageNode, a: f32) -> Option<Self> {
306        let ptr = unsafe { ffi::mps_cnn_neuron_relu_node_new(source.as_ptr(), a) };
307        if ptr.is_null() {
308            None
309        } else {
310            Some(Self { ptr })
311        }
312    }
313}
314impl_filter_result_image!(CnnNeuronReluNode);
315
316opaque_handle!(CnnPoolingMaxNode, "Wraps `MPSCNNPoolingMaxNode`.");
317impl CnnPoolingMaxNode {
318    /// Wraps a constructor on `MPSCNNPoolingMaxNode`.
319    #[must_use]
320    pub fn new(source: &NNImageNode, filter_size: usize, stride: usize) -> Option<Self> {
321        let ptr =
322            unsafe { ffi::mps_cnn_pooling_max_node_new(source.as_ptr(), filter_size, stride) };
323        if ptr.is_null() {
324            None
325        } else {
326            Some(Self { ptr })
327        }
328    }
329}
330impl_filter_result_image!(CnnPoolingMaxNode);
331
332opaque_handle!(CnnSoftMaxNode, "Wraps `MPSCNNSoftMaxNode`.");
333impl CnnSoftMaxNode {
334    /// Wraps a constructor on `MPSCNNSoftMax`.
335    #[must_use]
336    pub fn new(source: &NNImageNode) -> Option<Self> {
337        let ptr = unsafe { ffi::mps_cnn_softmax_node_new(source.as_ptr()) };
338        if ptr.is_null() {
339            None
340        } else {
341            Some(Self { ptr })
342        }
343    }
344}
345impl_filter_result_image!(CnnSoftMaxNode);
346
347opaque_handle!(CnnUpsamplingNearestNode, "Wraps `MPSCNNUpsamplingNearestNode`.");
348impl CnnUpsamplingNearestNode {
349    /// Wraps a constructor on `MPSCNNUpsamplingNearestNode`.
350    #[must_use]
351    pub fn new(source: &NNImageNode, scale_x: usize, scale_y: usize) -> Option<Self> {
352        let ptr =
353            unsafe { ffi::mps_cnn_upsampling_nearest_node_new(source.as_ptr(), scale_x, scale_y) };
354        if ptr.is_null() {
355            None
356        } else {
357            Some(Self { ptr })
358        }
359    }
360}
361impl_filter_result_image!(CnnUpsamplingNearestNode);
362
363opaque_handle!(NNGraph, "Wraps `MPSNNGraph`.");
364impl NNGraph {
365    /// Wraps a constructor on `MPSNNGraph`.
366    #[must_use]
367    pub fn new(
368        device: &MetalDevice,
369        result_image: &NNImageNode,
370        result_image_is_needed: bool,
371    ) -> Option<Self> {
372        let ptr = unsafe {
373            ffi::mps_nn_graph_new(
374                device.as_ptr(),
375                result_image.as_ptr(),
376                result_image_is_needed,
377            )
378        };
379        if ptr.is_null() {
380            None
381        } else {
382            Some(Self { ptr })
383        }
384    }
385
386    /// Wraps the corresponding `MPSNNGraph` method.
387    #[must_use]
388    pub fn source_image_count(&self) -> usize {
389        unsafe { ffi::mps_nn_graph_source_image_count(self.ptr) }
390    }
391
392    /// Wraps the corresponding `MPSNNGraph` method.
393    #[must_use]
394    pub fn format(&self) -> usize {
395        unsafe { ffi::mps_nn_graph_format(self.ptr) }
396    }
397
398    /// Wraps the corresponding `MPSNNGraph` setter.
399    pub fn set_format(&self, format: usize) {
400        unsafe { ffi::mps_nn_graph_set_format(self.ptr, format) };
401    }
402
403    /// Wraps the corresponding `MPSNNGraph` setter.
404    pub fn set_output_state_is_temporary(&self, temporary: bool) {
405        unsafe { ffi::mps_nn_graph_set_output_state_is_temporary(self.ptr, temporary) };
406    }
407
408    /// Wraps the corresponding `MPSNNGraph` method.
409    pub fn use_default_destination_image_allocator(&self) {
410        unsafe { ffi::mps_nn_graph_use_default_destination_image_allocator(self.ptr) };
411    }
412
413    /// Wraps the corresponding `MPSNNGraph` method.
414    pub fn reload_from_data_sources(&self) {
415        unsafe { ffi::mps_nn_graph_reload_from_data_sources(self.ptr) };
416    }
417
418    /// Wraps the corresponding `MPSNNGraph` encode entry point.
419    #[must_use]
420    pub fn encode(
421        &self,
422        command_buffer: &CommandBuffer,
423        source_images: &[&Image],
424    ) -> Option<Image> {
425        let handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
426        let source_handles = if handles.is_empty() {
427            ptr::null()
428        } else {
429            handles.as_ptr()
430        };
431        let ptr = unsafe {
432            ffi::mps_nn_graph_encode(
433                self.ptr,
434                command_buffer.as_ptr(),
435                source_images.len(),
436                source_handles,
437            )
438        };
439        if ptr.is_null() {
440            None
441        } else {
442            Some(unsafe { Image::from_raw(ptr) })
443        }
444    }
445}
446
447opaque_handle!(CnnConvolutionDescriptor, "Wraps `MPSCNNConvolutionDescriptor`.");
448impl CnnConvolutionDescriptor {
449    /// Wraps a constructor on `MPSCNNConvolutionDescriptor`.
450    #[must_use]
451    pub fn new(
452        kernel_width: usize,
453        kernel_height: usize,
454        input_feature_channels: usize,
455        output_feature_channels: usize,
456    ) -> Option<Self> {
457        let ptr = unsafe {
458            ffi::mps_cnn_convolution_descriptor_new(
459                kernel_width,
460                kernel_height,
461                input_feature_channels,
462                output_feature_channels,
463            )
464        };
465        if ptr.is_null() {
466            None
467        } else {
468            Some(Self { ptr })
469        }
470    }
471
472    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
473    #[must_use]
474    pub fn kernel_width(&self) -> usize {
475        unsafe { ffi::mps_cnn_convolution_descriptor_kernel_width(self.ptr) }
476    }
477
478    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
479    #[must_use]
480    pub fn kernel_height(&self) -> usize {
481        unsafe { ffi::mps_cnn_convolution_descriptor_kernel_height(self.ptr) }
482    }
483
484    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
485    #[must_use]
486    pub fn input_feature_channels(&self) -> usize {
487        unsafe { ffi::mps_cnn_convolution_descriptor_input_feature_channels(self.ptr) }
488    }
489
490    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
491    #[must_use]
492    pub fn output_feature_channels(&self) -> usize {
493        unsafe { ffi::mps_cnn_convolution_descriptor_output_feature_channels(self.ptr) }
494    }
495
496    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
497    #[must_use]
498    pub fn stride_in_pixels_x(&self) -> usize {
499        unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_x(self.ptr) }
500    }
501
502    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` setter.
503    pub fn set_stride_in_pixels_x(&self, value: usize) {
504        unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_x(self.ptr, value) };
505    }
506
507    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
508    #[must_use]
509    pub fn stride_in_pixels_y(&self) -> usize {
510        unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_y(self.ptr) }
511    }
512
513    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` setter.
514    pub fn set_stride_in_pixels_y(&self, value: usize) {
515        unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_y(self.ptr, value) };
516    }
517
518    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
519    #[must_use]
520    pub fn groups(&self) -> usize {
521        unsafe { ffi::mps_cnn_convolution_descriptor_groups(self.ptr) }
522    }
523
524    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` setter.
525    pub fn set_groups(&self, value: usize) {
526        unsafe { ffi::mps_cnn_convolution_descriptor_set_groups(self.ptr, value) };
527    }
528
529    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
530    #[must_use]
531    pub fn dilation_rate_x(&self) -> usize {
532        unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_x(self.ptr) }
533    }
534
535    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` setter.
536    pub fn set_dilation_rate_x(&self, value: usize) {
537        unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_x(self.ptr, value) };
538    }
539
540    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` method.
541    #[must_use]
542    pub fn dilation_rate_y(&self) -> usize {
543        unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_y(self.ptr) }
544    }
545
546    /// Wraps the corresponding `MPSCNNConvolutionDescriptor` setter.
547    pub fn set_dilation_rate_y(&self, value: usize) {
548        unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_y(self.ptr, value) };
549    }
550}
551
552opaque_handle!(RnnSingleGateDescriptor, "Wraps `MPSRNNSingleGateDescriptor`.");
553impl RnnSingleGateDescriptor {
554    /// Wraps a constructor on `MPSRNNSingleGateDescriptor`.
555    #[must_use]
556    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
557        let ptr = unsafe {
558            ffi::mps_rnn_single_gate_descriptor_new(input_feature_channels, output_feature_channels)
559        };
560        if ptr.is_null() {
561            None
562        } else {
563            Some(Self { ptr })
564        }
565    }
566
567    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` method.
568    #[must_use]
569    pub fn input_feature_channels(&self) -> usize {
570        unsafe { ffi::mps_rnn_single_gate_descriptor_input_feature_channels(self.ptr) }
571    }
572
573    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` setter.
574    pub fn set_input_feature_channels(&self, value: usize) {
575        unsafe { ffi::mps_rnn_single_gate_descriptor_set_input_feature_channels(self.ptr, value) };
576    }
577
578    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` method.
579    #[must_use]
580    pub fn output_feature_channels(&self) -> usize {
581        unsafe { ffi::mps_rnn_single_gate_descriptor_output_feature_channels(self.ptr) }
582    }
583
584    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` setter.
585    pub fn set_output_feature_channels(&self, value: usize) {
586        unsafe { ffi::mps_rnn_single_gate_descriptor_set_output_feature_channels(self.ptr, value) };
587    }
588
589    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` method.
590    #[must_use]
591    pub fn use_layer_input_unit_transform_mode(&self) -> bool {
592        unsafe { ffi::mps_rnn_single_gate_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
593    }
594
595    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` setter.
596    pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
597        unsafe {
598            ffi::mps_rnn_single_gate_descriptor_set_use_layer_input_unit_transform_mode(
599                self.ptr, value,
600            );
601        };
602    }
603
604    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` method.
605    #[must_use]
606    pub fn use_float32_weights(&self) -> bool {
607        unsafe { ffi::mps_rnn_single_gate_descriptor_use_float32_weights(self.ptr) }
608    }
609
610    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` setter.
611    pub fn set_use_float32_weights(&self, value: bool) {
612        unsafe { ffi::mps_rnn_single_gate_descriptor_set_use_float32_weights(self.ptr, value) };
613    }
614
615    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` method.
616    #[must_use]
617    pub fn layer_sequence_direction(&self) -> usize {
618        unsafe { ffi::mps_rnn_single_gate_descriptor_layer_sequence_direction(self.ptr) }
619    }
620
621    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` setter.
622    pub fn set_layer_sequence_direction(&self, value: usize) {
623        unsafe {
624            ffi::mps_rnn_single_gate_descriptor_set_layer_sequence_direction(self.ptr, value);
625        };
626    }
627
628    /// Wraps the corresponding `MPSRNNSingleGateDescriptor` conversion helper.
629    #[must_use]
630    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
631        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
632    }
633}
634
635opaque_handle!(CnnConvolution, "Wraps `MPSCNNConvolution`.");
636impl CnnConvolution {
637    /// Wraps a constructor on `MPSCNNConvolution`.
638    #[must_use]
639    pub fn new(
640        device: &MetalDevice,
641        descriptor: &CnnConvolutionDescriptor,
642        kernel_weights: &[f32],
643        bias_terms: Option<&[f32]>,
644        flags: usize,
645    ) -> Option<Self> {
646        if kernel_weights.is_empty() {
647            return None;
648        }
649        let bias_terms_ptr = bias_terms.map_or(ptr::null(), <[f32]>::as_ptr);
650        let ptr = unsafe {
651            ffi::mps_cnn_convolution_new(
652                device.as_ptr(),
653                descriptor.as_ptr(),
654                kernel_weights.as_ptr(),
655                bias_terms_ptr,
656                flags,
657            )
658        };
659        if ptr.is_null() {
660            None
661        } else {
662            Some(Self { ptr })
663        }
664    }
665
666    /// Wraps the corresponding `MPSCNNConvolution` method.
667    #[must_use]
668    pub fn input_feature_channels(&self) -> usize {
669        unsafe { ffi::mps_cnn_convolution_input_feature_channels(self.ptr) }
670    }
671
672    /// Wraps the corresponding `MPSCNNConvolution` method.
673    #[must_use]
674    pub fn output_feature_channels(&self) -> usize {
675        unsafe { ffi::mps_cnn_convolution_output_feature_channels(self.ptr) }
676    }
677
678    /// Wraps the corresponding `MPSCNNConvolution` method.
679    #[must_use]
680    pub fn groups(&self) -> usize {
681        unsafe { ffi::mps_cnn_convolution_groups(self.ptr) }
682    }
683
684    /// Wraps the corresponding `MPSCNNConvolution` method.
685    #[must_use]
686    pub fn sub_pixel_scale_factor(&self) -> usize {
687        unsafe { ffi::mps_cnn_convolution_sub_pixel_scale_factor(self.ptr) }
688    }
689
690    /// Wraps the corresponding `MPSCNNConvolution` method.
691    #[must_use]
692    pub fn channel_multiplier(&self) -> usize {
693        unsafe { ffi::mps_cnn_convolution_channel_multiplier(self.ptr) }
694    }
695
696    /// Wraps the corresponding `MPSCNNConvolution` method.
697    #[must_use]
698    pub fn accumulator_precision_option(&self) -> usize {
699        unsafe { ffi::mps_cnn_convolution_accumulator_precision_option(self.ptr) }
700    }
701
702    /// Wraps the corresponding `MPSCNNConvolution` setter.
703    pub fn set_accumulator_precision_option(&self, value: usize) {
704        unsafe { ffi::mps_cnn_convolution_set_accumulator_precision_option(self.ptr, value) };
705    }
706
707    /// Wraps the corresponding `MPSCNNConvolution` encode entry point.
708    pub fn encode_image(
709        &self,
710        command_buffer: &CommandBuffer,
711        source: &Image,
712        destination: &Image,
713    ) {
714        unsafe {
715            ffi::mps_cnn_convolution_encode_image(
716                self.ptr,
717                command_buffer.as_ptr(),
718                source.as_ptr(),
719                destination.as_ptr(),
720            );
721        };
722    }
723}
724
725opaque_handle!(CnnConvolutionWeightsAndBiasesState, "Wraps `MPSCNNConvolutionWeightsAndBiasesState`.");
726impl CnnConvolutionWeightsAndBiasesState {
727    /// Wraps a constructor on `MPSCNNConvolutionWeightsAndBiasesState`.
728    #[must_use]
729    pub fn new_with_buffers(weights: &MetalBuffer, biases: Option<&MetalBuffer>) -> Option<Self> {
730        let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
731        let ptr = unsafe {
732            ffi::mps_cnn_convolution_weights_and_biases_state_new(weights.as_ptr(), biases_ptr)
733        };
734        if ptr.is_null() {
735            None
736        } else {
737            Some(Self { ptr })
738        }
739    }
740
741    /// Wraps a constructor on `MPSCNNConvolutionWeightsAndBiasesState`.
742    #[must_use]
743    pub fn new_with_offsets(
744        weights: &MetalBuffer,
745        weights_offset: usize,
746        biases: Option<&MetalBuffer>,
747        biases_offset: usize,
748        descriptor: &CnnConvolutionDescriptor,
749    ) -> Option<Self> {
750        let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
751        let ptr = unsafe {
752            ffi::mps_cnn_convolution_weights_and_biases_state_new_with_offsets(
753                weights.as_ptr(),
754                weights_offset,
755                biases_ptr,
756                biases_offset,
757                descriptor.as_ptr(),
758            )
759        };
760        if ptr.is_null() {
761            None
762        } else {
763            Some(Self { ptr })
764        }
765    }
766
767    /// Wraps a constructor on `MPSCNNConvolutionWeightsAndBiasesState`.
768    #[must_use]
769    pub fn new_with_device(
770        device: &MetalDevice,
771        descriptor: &CnnConvolutionDescriptor,
772    ) -> Option<Self> {
773        let ptr = unsafe {
774            ffi::mps_cnn_convolution_weights_and_biases_state_new_with_device(
775                device.as_ptr(),
776                descriptor.as_ptr(),
777            )
778        };
779        if ptr.is_null() {
780            None
781        } else {
782            Some(Self { ptr })
783        }
784    }
785
786    /// Wraps the corresponding `MPSCNNConvolutionWeightsAndBiasesState` method.
787    #[must_use]
788    pub fn weights_offset(&self) -> usize {
789        unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_weights_offset(self.ptr) }
790    }
791
792    /// Wraps the corresponding `MPSCNNConvolutionWeightsAndBiasesState` method.
793    #[must_use]
794    pub fn biases_offset(&self) -> usize {
795        unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_biases_offset(self.ptr) }
796    }
797}
798
799opaque_handle!(NNOptimizerDescriptor, "Wraps `MPSNNOptimizerDescriptor`.");
800impl NNOptimizerDescriptor {
801    /// Wraps a constructor on `MPSNNOptimizerDescriptor`.
802    #[must_use]
803    pub fn new(
804        learning_rate: f32,
805        gradient_rescale: f32,
806        regularization_type: usize,
807        regularization_scale: f32,
808    ) -> Option<Self> {
809        let ptr = unsafe {
810            ffi::mps_nn_optimizer_descriptor_new(
811                learning_rate,
812                gradient_rescale,
813                regularization_type,
814                regularization_scale,
815            )
816        };
817        if ptr.is_null() {
818            None
819        } else {
820            Some(Self { ptr })
821        }
822    }
823
824    /// Wraps the corresponding `MPSNNOptimizerDescriptor` method.
825    #[must_use]
826    pub fn with_gradient_clipping(
827        learning_rate: f32,
828        gradient_rescale: f32,
829        apply_gradient_clipping: bool,
830        gradient_clip_max: f32,
831        gradient_clip_min: f32,
832        regularization_type: usize,
833        regularization_scale: f32,
834    ) -> Option<Self> {
835        let ptr = unsafe {
836            ffi::mps_nn_optimizer_descriptor_new_with_gradient_clipping(
837                learning_rate,
838                gradient_rescale,
839                apply_gradient_clipping,
840                gradient_clip_max,
841                gradient_clip_min,
842                regularization_type,
843                regularization_scale,
844            )
845        };
846        if ptr.is_null() {
847            None
848        } else {
849            Some(Self { ptr })
850        }
851    }
852
853    /// Wraps the corresponding `MPSNNOptimizerDescriptor` method.
854    #[must_use]
855    pub fn learning_rate(&self) -> f32 {
856        unsafe { ffi::mps_nn_optimizer_descriptor_learning_rate(self.ptr) }
857    }
858
859    /// Wraps the corresponding `MPSNNOptimizerDescriptor` setter.
860    pub fn set_learning_rate(&self, value: f32) {
861        unsafe { ffi::mps_nn_optimizer_descriptor_set_learning_rate(self.ptr, value) };
862    }
863
864    /// Wraps the corresponding `MPSNNOptimizerDescriptor` method.
865    #[must_use]
866    pub fn gradient_rescale(&self) -> f32 {
867        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_rescale(self.ptr) }
868    }
869
870    /// Wraps the corresponding `MPSNNOptimizerDescriptor` setter.
871    pub fn set_gradient_rescale(&self, value: f32) {
872        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_rescale(self.ptr, value) };
873    }
874
875    /// Wraps the corresponding `MPSNNOptimizerDescriptor` method.
876    #[must_use]
877    pub fn apply_gradient_clipping(&self) -> bool {
878        unsafe { ffi::mps_nn_optimizer_descriptor_apply_gradient_clipping(self.ptr) }
879    }
880
881    /// Wraps the corresponding `MPSNNOptimizerDescriptor` setter.
882    pub fn set_apply_gradient_clipping(&self, value: bool) {
883        unsafe { ffi::mps_nn_optimizer_descriptor_set_apply_gradient_clipping(self.ptr, value) };
884    }
885
886    /// Wraps the corresponding `MPSNNOptimizerDescriptor` method.
887    #[must_use]
888    pub fn gradient_clip_max(&self) -> f32 {
889        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_max(self.ptr) }
890    }
891
892    /// Wraps the corresponding `MPSNNOptimizerDescriptor` setter.
893    pub fn set_gradient_clip_max(&self, value: f32) {
894        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_max(self.ptr, value) };
895    }
896
897    /// Wraps the corresponding `MPSNNOptimizerDescriptor` method.
898    #[must_use]
899    pub fn gradient_clip_min(&self) -> f32 {
900        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_min(self.ptr) }
901    }
902
903    /// Wraps the corresponding `MPSNNOptimizerDescriptor` setter.
904    pub fn set_gradient_clip_min(&self, value: f32) {
905        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_min(self.ptr, value) };
906    }
907
908    /// Wraps the corresponding `MPSNNOptimizerDescriptor` method.
909    #[must_use]
910    pub fn regularization_scale(&self) -> f32 {
911        unsafe { ffi::mps_nn_optimizer_descriptor_regularization_scale(self.ptr) }
912    }
913
914    /// Wraps the corresponding `MPSNNOptimizerDescriptor` setter.
915    pub fn set_regularization_scale(&self, value: f32) {
916        unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_scale(self.ptr, value) };
917    }
918
919    /// Wraps the corresponding `MPSNNOptimizerDescriptor` method.
920    #[must_use]
921    pub fn regularization_type(&self) -> usize {
922        unsafe { ffi::mps_nn_optimizer_descriptor_regularization_type(self.ptr) }
923    }
924
925    /// Wraps the corresponding `MPSNNOptimizerDescriptor` setter.
926    pub fn set_regularization_type(&self, value: usize) {
927        unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_type(self.ptr, value) };
928    }
929}
930
// Type-erased optimizer handle. The concrete optimizers in this file
// (SGD, RMSProp, Adam) hand one out through their `as_optimizer` helpers.
// `impl_optimizer_common!` is defined outside this view and presumably
// attaches the accessors shared by every optimizer wrapper.
opaque_handle!(NNOptimizer, "Wraps `MPSNNOptimizer`.");
impl_optimizer_common!(NNOptimizer);
933
934opaque_handle!(NNOptimizerStochasticGradientDescent, "Wraps `MPSNNOptimizerStochasticGradientDescent`.");
935impl_optimizer_common!(NNOptimizerStochasticGradientDescent);
936impl NNOptimizerStochasticGradientDescent {
937    /// Wraps a constructor on `MPSNNOptimizerStochasticGradientDescent`.
938    #[must_use]
939    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
940        let ptr = unsafe { ffi::mps_nn_optimizer_sgd_new(device.as_ptr(), learning_rate) };
941        if ptr.is_null() {
942            None
943        } else {
944            Some(Self { ptr })
945        }
946    }
947
948    /// Wraps a constructor on `MPSNNOptimizerStochasticGradientDescent`.
949    #[must_use]
950    pub fn new_with_options(
951        device: &MetalDevice,
952        momentum_scale: f32,
953        use_nesterov_momentum: bool,
954        optimizer_descriptor: &NNOptimizerDescriptor,
955    ) -> Option<Self> {
956        let ptr = unsafe {
957            ffi::mps_nn_optimizer_sgd_new_with_options(
958                device.as_ptr(),
959                momentum_scale,
960                use_nesterov_momentum,
961                optimizer_descriptor.as_ptr(),
962            )
963        };
964        if ptr.is_null() {
965            None
966        } else {
967            Some(Self { ptr })
968        }
969    }
970
971    /// Wraps the corresponding `MPSNNOptimizerStochasticGradientDescent` conversion helper.
972    #[must_use]
973    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
974        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
975    }
976
977    /// Wraps the corresponding `MPSNNOptimizerStochasticGradientDescent` method.
978    #[must_use]
979    pub fn momentum_scale(&self) -> f32 {
980        unsafe { ffi::mps_nn_optimizer_sgd_momentum_scale(self.ptr) }
981    }
982
983    /// Wraps the corresponding `MPSNNOptimizerStochasticGradientDescent` method.
984    #[must_use]
985    pub fn use_nesterov_momentum(&self) -> bool {
986        unsafe { ffi::mps_nn_optimizer_sgd_use_nesterov_momentum(self.ptr) }
987    }
988
989    /// Wraps the corresponding `MPSNNOptimizerStochasticGradientDescent` encode entry point.
990    pub fn encode_vector(
991        &self,
992        command_buffer: &CommandBuffer,
993        input_gradient_vector: &Vector,
994        input_values_vector: &Vector,
995        input_momentum_vector: Option<&Vector>,
996        result_values_vector: &Vector,
997    ) {
998        let input_momentum_ptr = input_momentum_vector.map_or(ptr::null_mut(), Vector::as_ptr);
999        unsafe {
1000            ffi::mps_nn_optimizer_sgd_encode_vector(
1001                self.ptr,
1002                command_buffer.as_ptr(),
1003                input_gradient_vector.as_ptr(),
1004                input_values_vector.as_ptr(),
1005                input_momentum_ptr,
1006                result_values_vector.as_ptr(),
1007            );
1008        };
1009    }
1010
1011    /// Wraps the corresponding `MPSNNOptimizerStochasticGradientDescent` encode entry point.
1012    pub fn encode_matrix(
1013        &self,
1014        command_buffer: &CommandBuffer,
1015        input_gradient_matrix: &Matrix,
1016        input_values_matrix: &Matrix,
1017        input_momentum_matrix: Option<&Matrix>,
1018        result_values_matrix: &Matrix,
1019    ) {
1020        let input_momentum_ptr = input_momentum_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
1021        unsafe {
1022            ffi::mps_nn_optimizer_sgd_encode_matrix(
1023                self.ptr,
1024                command_buffer.as_ptr(),
1025                input_gradient_matrix.as_ptr(),
1026                input_values_matrix.as_ptr(),
1027                input_momentum_ptr,
1028                result_values_matrix.as_ptr(),
1029            );
1030        };
1031    }
1032}
1033
1034opaque_handle!(NNOptimizerRmsProp, "Wraps `MPSNNOptimizerRMSProp`.");
1035impl_optimizer_common!(NNOptimizerRmsProp);
1036impl NNOptimizerRmsProp {
1037    /// Wraps a constructor on `MPSNNOptimizerRMSProp`.
1038    #[must_use]
1039    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
1040        let ptr = unsafe { ffi::mps_nn_optimizer_rmsprop_new(device.as_ptr(), learning_rate) };
1041        if ptr.is_null() {
1042            None
1043        } else {
1044            Some(Self { ptr })
1045        }
1046    }
1047
1048    /// Wraps a constructor on `MPSNNOptimizerRMSProp`.
1049    #[must_use]
1050    pub fn new_with_options(
1051        device: &MetalDevice,
1052        decay: f64,
1053        epsilon: f32,
1054        optimizer_descriptor: &NNOptimizerDescriptor,
1055    ) -> Option<Self> {
1056        let ptr = unsafe {
1057            ffi::mps_nn_optimizer_rmsprop_new_with_options(
1058                device.as_ptr(),
1059                decay,
1060                epsilon,
1061                optimizer_descriptor.as_ptr(),
1062            )
1063        };
1064        if ptr.is_null() {
1065            None
1066        } else {
1067            Some(Self { ptr })
1068        }
1069    }
1070
1071    /// Wraps the corresponding `MPSNNOptimizerRMSProp` conversion helper.
1072    #[must_use]
1073    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
1074        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
1075    }
1076
1077    /// Wraps the corresponding `MPSNNOptimizerRMSProp` method.
1078    #[must_use]
1079    pub fn decay(&self) -> f64 {
1080        unsafe { ffi::mps_nn_optimizer_rmsprop_decay(self.ptr) }
1081    }
1082
1083    /// Wraps the corresponding `MPSNNOptimizerRMSProp` method.
1084    #[must_use]
1085    pub fn epsilon(&self) -> f32 {
1086        unsafe { ffi::mps_nn_optimizer_rmsprop_epsilon(self.ptr) }
1087    }
1088
1089    /// Wraps the corresponding `MPSNNOptimizerRMSProp` encode entry point.
1090    pub fn encode_vector(
1091        &self,
1092        command_buffer: &CommandBuffer,
1093        input_gradient_vector: &Vector,
1094        input_values_vector: &Vector,
1095        input_sum_of_squares_vector: &Vector,
1096        result_values_vector: &Vector,
1097    ) {
1098        unsafe {
1099            ffi::mps_nn_optimizer_rmsprop_encode_vector(
1100                self.ptr,
1101                command_buffer.as_ptr(),
1102                input_gradient_vector.as_ptr(),
1103                input_values_vector.as_ptr(),
1104                input_sum_of_squares_vector.as_ptr(),
1105                result_values_vector.as_ptr(),
1106            );
1107        };
1108    }
1109
1110    /// Wraps the corresponding `MPSNNOptimizerRMSProp` encode entry point.
1111    pub fn encode_matrix(
1112        &self,
1113        command_buffer: &CommandBuffer,
1114        input_gradient_matrix: &Matrix,
1115        input_values_matrix: &Matrix,
1116        input_sum_of_squares_matrix: &Matrix,
1117        result_values_matrix: &Matrix,
1118    ) {
1119        unsafe {
1120            ffi::mps_nn_optimizer_rmsprop_encode_matrix(
1121                self.ptr,
1122                command_buffer.as_ptr(),
1123                input_gradient_matrix.as_ptr(),
1124                input_values_matrix.as_ptr(),
1125                input_sum_of_squares_matrix.as_ptr(),
1126                result_values_matrix.as_ptr(),
1127            );
1128        };
1129    }
1130}
1131
1132opaque_handle!(NNOptimizerAdam, "Wraps `MPSNNOptimizerAdam`.");
1133impl_optimizer_common!(NNOptimizerAdam);
1134impl NNOptimizerAdam {
1135    /// Wraps a constructor on `MPSNNOptimizerAdam`.
1136    #[must_use]
1137    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
1138        let ptr = unsafe { ffi::mps_nn_optimizer_adam_new(device.as_ptr(), learning_rate) };
1139        if ptr.is_null() {
1140            None
1141        } else {
1142            Some(Self { ptr })
1143        }
1144    }
1145
1146    /// Wraps a constructor on `MPSNNOptimizerAdam`.
1147    #[must_use]
1148    pub fn new_with_options(
1149        device: &MetalDevice,
1150        beta1: f64,
1151        beta2: f64,
1152        epsilon: f32,
1153        time_step: usize,
1154        optimizer_descriptor: &NNOptimizerDescriptor,
1155    ) -> Option<Self> {
1156        let ptr = unsafe {
1157            ffi::mps_nn_optimizer_adam_new_with_options(
1158                device.as_ptr(),
1159                beta1,
1160                beta2,
1161                epsilon,
1162                time_step,
1163                optimizer_descriptor.as_ptr(),
1164            )
1165        };
1166        if ptr.is_null() {
1167            None
1168        } else {
1169            Some(Self { ptr })
1170        }
1171    }
1172
1173    /// Wraps the corresponding `MPSNNOptimizerAdam` conversion helper.
1174    #[must_use]
1175    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
1176        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
1177    }
1178
1179    /// Wraps the corresponding `MPSNNOptimizerAdam` method.
1180    #[must_use]
1181    pub fn beta1(&self) -> f64 {
1182        unsafe { ffi::mps_nn_optimizer_adam_beta1(self.ptr) }
1183    }
1184
1185    /// Wraps the corresponding `MPSNNOptimizerAdam` method.
1186    #[must_use]
1187    pub fn beta2(&self) -> f64 {
1188        unsafe { ffi::mps_nn_optimizer_adam_beta2(self.ptr) }
1189    }
1190
1191    /// Wraps the corresponding `MPSNNOptimizerAdam` method.
1192    #[must_use]
1193    pub fn epsilon(&self) -> f32 {
1194        unsafe { ffi::mps_nn_optimizer_adam_epsilon(self.ptr) }
1195    }
1196
1197    /// Wraps the corresponding `MPSNNOptimizerAdam` method.
1198    #[must_use]
1199    pub fn time_step(&self) -> usize {
1200        unsafe { ffi::mps_nn_optimizer_adam_time_step(self.ptr) }
1201    }
1202
1203    /// Wraps the corresponding `MPSNNOptimizerAdam` setter.
1204    pub fn set_time_step(&self, value: usize) {
1205        unsafe { ffi::mps_nn_optimizer_adam_set_time_step(self.ptr, value) };
1206    }
1207
1208    /// Wraps the corresponding `MPSNNOptimizerAdam` encode entry point.
1209    pub fn encode_vector(
1210        &self,
1211        command_buffer: &CommandBuffer,
1212        input_gradient_vector: &Vector,
1213        input_values_vector: &Vector,
1214        input_momentum_vector: &Vector,
1215        input_velocity_vector: &Vector,
1216        result_values_vector: &Vector,
1217    ) {
1218        unsafe {
1219            ffi::mps_nn_optimizer_adam_encode_vector(
1220                self.ptr,
1221                command_buffer.as_ptr(),
1222                input_gradient_vector.as_ptr(),
1223                input_values_vector.as_ptr(),
1224                input_momentum_vector.as_ptr(),
1225                input_velocity_vector.as_ptr(),
1226                result_values_vector.as_ptr(),
1227            );
1228        };
1229    }
1230
1231    /// Wraps the corresponding `MPSNNOptimizerAdam` encode entry point.
1232    pub fn encode_matrix(
1233        &self,
1234        command_buffer: &CommandBuffer,
1235        input_gradient_matrix: &Matrix,
1236        input_values_matrix: &Matrix,
1237        input_momentum_matrix: &Matrix,
1238        input_velocity_matrix: &Matrix,
1239        result_values_matrix: &Matrix,
1240    ) {
1241        unsafe {
1242            ffi::mps_nn_optimizer_adam_encode_matrix(
1243                self.ptr,
1244                command_buffer.as_ptr(),
1245                input_gradient_matrix.as_ptr(),
1246                input_values_matrix.as_ptr(),
1247                input_momentum_matrix.as_ptr(),
1248                input_velocity_matrix.as_ptr(),
1249                result_values_matrix.as_ptr(),
1250            );
1251        };
1252    }
1253
1254    /// Wraps the corresponding `MPSNNOptimizerAdam` encode entry point.
1255    #[allow(clippy::too_many_arguments)]
1256    pub fn encode_amsgrad_vector(
1257        &self,
1258        command_buffer: &CommandBuffer,
1259        input_gradient_vector: &Vector,
1260        input_values_vector: &Vector,
1261        input_momentum_vector: &Vector,
1262        input_velocity_vector: &Vector,
1263        maximum_velocity_vector: Option<&Vector>,
1264        result_values_vector: &Vector,
1265    ) {
1266        let maximum_velocity_ptr = maximum_velocity_vector.map_or(ptr::null_mut(), Vector::as_ptr);
1267        unsafe {
1268            ffi::mps_nn_optimizer_adam_encode_amsgrad_vector(
1269                self.ptr,
1270                command_buffer.as_ptr(),
1271                input_gradient_vector.as_ptr(),
1272                input_values_vector.as_ptr(),
1273                input_momentum_vector.as_ptr(),
1274                input_velocity_vector.as_ptr(),
1275                maximum_velocity_ptr,
1276                result_values_vector.as_ptr(),
1277            );
1278        };
1279    }
1280
1281    /// Wraps the corresponding `MPSNNOptimizerAdam` encode entry point.
1282    #[allow(clippy::too_many_arguments)]
1283    pub fn encode_amsgrad_matrix(
1284        &self,
1285        command_buffer: &CommandBuffer,
1286        input_gradient_matrix: &Matrix,
1287        input_values_matrix: &Matrix,
1288        input_momentum_matrix: &Matrix,
1289        input_velocity_matrix: &Matrix,
1290        maximum_velocity_matrix: Option<&Matrix>,
1291        result_values_matrix: &Matrix,
1292    ) {
1293        let maximum_velocity_ptr = maximum_velocity_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
1294        unsafe {
1295            ffi::mps_nn_optimizer_adam_encode_amsgrad_matrix(
1296                self.ptr,
1297                command_buffer.as_ptr(),
1298                input_gradient_matrix.as_ptr(),
1299                input_values_matrix.as_ptr(),
1300                input_momentum_matrix.as_ptr(),
1301                input_velocity_matrix.as_ptr(),
1302                maximum_velocity_ptr,
1303                result_values_matrix.as_ptr(),
1304            );
1305        };
1306    }
1307}
1308
// Base RNN descriptor handle. `GruDescriptor::as_descriptor` and
// `LstmDescriptor::as_descriptor` produce values of this type, and the
// `RnnImageInferenceLayer` constructors consume it. `impl_rnn_descriptor_common!`
// is defined outside this view and presumably adds the shared descriptor accessors.
opaque_handle!(RnnDescriptor, "Wraps `MPSRNNDescriptor`.");
impl_rnn_descriptor_common!(RnnDescriptor);
1311
1312opaque_handle!(GruDescriptor, "Wraps `MPSGRUDescriptor`.");
1313impl_rnn_descriptor_common!(GruDescriptor);
1314impl GruDescriptor {
1315    /// Wraps a constructor on `MPSGRUDescriptor`.
1316    #[must_use]
1317    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1318        let ptr =
1319            unsafe { ffi::mps_gru_descriptor_new(input_feature_channels, output_feature_channels) };
1320        if ptr.is_null() {
1321            None
1322        } else {
1323            Some(Self { ptr })
1324        }
1325    }
1326
1327    /// Wraps the corresponding `MPSGRUDescriptor` conversion helper.
1328    #[must_use]
1329    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1330        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1331    }
1332
1333    /// Wraps the corresponding `MPSGRUDescriptor` method.
1334    #[must_use]
1335    pub fn gate_pnorm_value(&self) -> f32 {
1336        unsafe { ffi::mps_gru_descriptor_gate_pnorm_value(self.ptr) }
1337    }
1338
1339    /// Wraps the corresponding `MPSGRUDescriptor` setter.
1340    pub fn set_gate_pnorm_value(&self, value: f32) {
1341        unsafe { ffi::mps_gru_descriptor_set_gate_pnorm_value(self.ptr, value) };
1342    }
1343
1344    /// Wraps the corresponding `MPSGRUDescriptor` method.
1345    #[must_use]
1346    pub fn flip_output_gates(&self) -> bool {
1347        unsafe { ffi::mps_gru_descriptor_flip_output_gates(self.ptr) }
1348    }
1349
1350    /// Wraps the corresponding `MPSGRUDescriptor` setter.
1351    pub fn set_flip_output_gates(&self, value: bool) {
1352        unsafe { ffi::mps_gru_descriptor_set_flip_output_gates(self.ptr, value) };
1353    }
1354}
1355
1356opaque_handle!(LstmDescriptor, "Wraps `MPSLSTMDescriptor`.");
1357impl_rnn_descriptor_common!(LstmDescriptor);
1358impl LstmDescriptor {
1359    /// Wraps a constructor on `MPSLSTMDescriptor`.
1360    #[must_use]
1361    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1362        let ptr = unsafe {
1363            ffi::mps_lstm_descriptor_new(input_feature_channels, output_feature_channels)
1364        };
1365        if ptr.is_null() {
1366            None
1367        } else {
1368            Some(Self { ptr })
1369        }
1370    }
1371
1372    /// Wraps the corresponding `MPSLSTMDescriptor` conversion helper.
1373    #[must_use]
1374    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1375        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1376    }
1377
1378    /// Wraps the corresponding `MPSLSTMDescriptor` method.
1379    #[must_use]
1380    pub fn memory_weights_are_diagonal(&self) -> bool {
1381        unsafe { ffi::mps_lstm_descriptor_memory_weights_are_diagonal(self.ptr) }
1382    }
1383
1384    /// Wraps the corresponding `MPSLSTMDescriptor` setter.
1385    pub fn set_memory_weights_are_diagonal(&self, value: bool) {
1386        unsafe { ffi::mps_lstm_descriptor_set_memory_weights_are_diagonal(self.ptr, value) };
1387    }
1388
1389    /// Wraps the corresponding `MPSLSTMDescriptor` method.
1390    #[must_use]
1391    pub fn cell_to_output_neuron_type(&self) -> usize {
1392        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_type(self.ptr) }
1393    }
1394
1395    /// Wraps the corresponding `MPSLSTMDescriptor` setter.
1396    pub fn set_cell_to_output_neuron_type(&self, value: usize) {
1397        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_type(self.ptr, value) };
1398    }
1399
1400    /// Wraps the corresponding `MPSLSTMDescriptor` method.
1401    #[must_use]
1402    pub fn cell_to_output_neuron_param_a(&self) -> f32 {
1403        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_a(self.ptr) }
1404    }
1405
1406    /// Wraps the corresponding `MPSLSTMDescriptor` setter.
1407    pub fn set_cell_to_output_neuron_param_a(&self, value: f32) {
1408        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_a(self.ptr, value) };
1409    }
1410
1411    /// Wraps the corresponding `MPSLSTMDescriptor` method.
1412    #[must_use]
1413    pub fn cell_to_output_neuron_param_b(&self) -> f32 {
1414        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_b(self.ptr) }
1415    }
1416
1417    /// Wraps the corresponding `MPSLSTMDescriptor` setter.
1418    pub fn set_cell_to_output_neuron_param_b(&self, value: f32) {
1419        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_b(self.ptr, value) };
1420    }
1421
1422    /// Wraps the corresponding `MPSLSTMDescriptor` method.
1423    #[must_use]
1424    pub fn cell_to_output_neuron_param_c(&self) -> f32 {
1425        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_c(self.ptr) }
1426    }
1427
1428    /// Wraps the corresponding `MPSLSTMDescriptor` setter.
1429    pub fn set_cell_to_output_neuron_param_c(&self, value: f32) {
1430        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_c(self.ptr, value) };
1431    }
1432}
1433
1434opaque_handle!(RnnRecurrentImageState, "Wraps `MPSRNNRecurrentImageState`.");
1435impl RnnRecurrentImageState {
1436    /// Wraps the corresponding `MPSRNNRecurrentImageState` method.
1437    #[must_use]
1438    pub fn recurrent_output_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1439        let ptr = unsafe {
1440            ffi::mps_rnn_recurrent_image_state_recurrent_output_image(self.ptr, layer_index)
1441        };
1442        if ptr.is_null() {
1443            None
1444        } else {
1445            Some(unsafe { Image::from_raw(ptr) })
1446        }
1447    }
1448
1449    /// Wraps the corresponding `MPSRNNRecurrentImageState` method.
1450    #[must_use]
1451    pub fn memory_cell_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1452        let ptr =
1453            unsafe { ffi::mps_rnn_recurrent_image_state_memory_cell_image(self.ptr, layer_index) };
1454        if ptr.is_null() {
1455            None
1456        } else {
1457            Some(unsafe { Image::from_raw(ptr) })
1458        }
1459    }
1460}
1461
1462opaque_handle!(RnnImageInferenceLayer, "Wraps `MPSRNNImageInferenceLayer`.");
1463impl RnnImageInferenceLayer {
1464    /// Wraps a constructor on `MPSRNNImageInferenceLayer`.
1465    #[must_use]
1466    pub fn new(device: &MetalDevice, descriptor: &RnnDescriptor) -> Option<Self> {
1467        let ptr =
1468            unsafe { ffi::mps_rnn_image_inference_layer_new(device.as_ptr(), descriptor.as_ptr()) };
1469        if ptr.is_null() {
1470            None
1471        } else {
1472            Some(Self { ptr })
1473        }
1474    }
1475
1476    /// Wraps a constructor on `MPSRNNImageInferenceLayer`.
1477    #[must_use]
1478    pub fn new_stack(device: &MetalDevice, descriptors: &[&RnnDescriptor]) -> Option<Self> {
1479        let handles: Vec<_> = descriptors
1480            .iter()
1481            .map(|descriptor| descriptor.as_ptr())
1482            .collect();
1483        let handles_ptr = if handles.is_empty() {
1484            ptr::null()
1485        } else {
1486            handles.as_ptr()
1487        };
1488        let ptr = unsafe {
1489            ffi::mps_rnn_image_inference_layer_new_stack(
1490                device.as_ptr(),
1491                descriptors.len(),
1492                handles_ptr,
1493            )
1494        };
1495        if ptr.is_null() {
1496            None
1497        } else {
1498            Some(Self { ptr })
1499        }
1500    }
1501
1502    /// Wraps the corresponding `MPSRNNImageInferenceLayer` method.
1503    #[must_use]
1504    pub fn input_feature_channels(&self) -> usize {
1505        unsafe { ffi::mps_rnn_image_inference_layer_input_feature_channels(self.ptr) }
1506    }
1507
1508    /// Wraps the corresponding `MPSRNNImageInferenceLayer` method.
1509    #[must_use]
1510    pub fn output_feature_channels(&self) -> usize {
1511        unsafe { ffi::mps_rnn_image_inference_layer_output_feature_channels(self.ptr) }
1512    }
1513
1514    /// Wraps the corresponding `MPSRNNImageInferenceLayer` method.
1515    #[must_use]
1516    pub fn number_of_layers(&self) -> usize {
1517        unsafe { ffi::mps_rnn_image_inference_layer_number_of_layers(self.ptr) }
1518    }
1519
1520    /// Wraps the corresponding `MPSRNNImageInferenceLayer` method.
1521    #[must_use]
1522    pub fn recurrent_output_is_temporary(&self) -> bool {
1523        unsafe { ffi::mps_rnn_image_inference_layer_recurrent_output_is_temporary(self.ptr) }
1524    }
1525
1526    /// Wraps the corresponding `MPSRNNImageInferenceLayer` setter.
1527    pub fn set_recurrent_output_is_temporary(&self, value: bool) {
1528        unsafe {
1529            ffi::mps_rnn_image_inference_layer_set_recurrent_output_is_temporary(self.ptr, value);
1530        }
1531    }
1532
1533    /// Wraps the corresponding `MPSRNNImageInferenceLayer` method.
1534    #[must_use]
1535    pub fn store_all_intermediate_states(&self) -> bool {
1536        unsafe { ffi::mps_rnn_image_inference_layer_store_all_intermediate_states(self.ptr) }
1537    }
1538
1539    /// Wraps the corresponding `MPSRNNImageInferenceLayer` setter.
1540    pub fn set_store_all_intermediate_states(&self, value: bool) {
1541        unsafe {
1542            ffi::mps_rnn_image_inference_layer_set_store_all_intermediate_states(self.ptr, value);
1543        }
1544    }
1545
1546    /// Wraps the corresponding `MPSRNNImageInferenceLayer` method.
1547    #[must_use]
1548    pub fn bidirectional_combine_mode(&self) -> usize {
1549        unsafe { ffi::mps_rnn_image_inference_layer_bidirectional_combine_mode(self.ptr) }
1550    }
1551
1552    /// Wraps the corresponding `MPSRNNImageInferenceLayer` setter.
1553    pub fn set_bidirectional_combine_mode(&self, value: usize) {
1554        unsafe {
1555            ffi::mps_rnn_image_inference_layer_set_bidirectional_combine_mode(self.ptr, value);
1556        }
1557    }
1558
1559    /// Wraps the corresponding `MPSRNNImageInferenceLayer` encode entry point.
1560    #[must_use]
1561    pub fn encode_sequence(
1562        &self,
1563        command_buffer: &CommandBuffer,
1564        source_images: &[&Image],
1565        destination_images: &[&Image],
1566        recurrent_input_state: Option<&RnnRecurrentImageState>,
1567    ) -> Option<RnnRecurrentImageState> {
1568        if source_images.len() != destination_images.len() {
1569            return None;
1570        }
1571        let source_handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
1572        let destination_handles: Vec<_> = destination_images
1573            .iter()
1574            .map(|image| image.as_ptr())
1575            .collect();
1576        let source_ptr = if source_handles.is_empty() {
1577            ptr::null()
1578        } else {
1579            source_handles.as_ptr()
1580        };
1581        let destination_ptr = if destination_handles.is_empty() {
1582            ptr::null()
1583        } else {
1584            destination_handles.as_ptr()
1585        };
1586        let recurrent_input_ptr =
1587            recurrent_input_state.map_or(ptr::null_mut(), RnnRecurrentImageState::as_ptr);
1588        let ptr = unsafe {
1589            ffi::mps_rnn_image_inference_layer_encode_sequence(
1590                self.ptr,
1591                command_buffer.as_ptr(),
1592                source_images.len(),
1593                source_ptr,
1594                destination_ptr,
1595                recurrent_input_ptr,
1596            )
1597        };
1598        if ptr.is_null() {
1599            None
1600        } else {
1601            Some(RnnRecurrentImageState { ptr })
1602        }
1603    }
1604}