Skip to main content

apple_mps/
neural.rs

1use crate::ffi;
2use crate::image::Image;
3use crate::matrix::{Matrix, Vector};
4use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
5use core::ffi::c_void;
6use core::ptr;
7
/// Values of `MPSRNNSequenceDirection`, used with the
/// `set_layer_sequence_direction` accessors on RNN descriptors.
pub mod rnn_sequence_direction {
    /// `MPSRNNSequenceDirectionForward`.
    pub const FORWARD: usize = 0;
    /// `MPSRNNSequenceDirectionBackward`.
    pub const BACKWARD: usize = 1;
}
13
/// Values of `MPSRNNBidirectionalCombineMode`.
pub mod rnn_bidirectional_combine_mode {
    /// `MPSRNNBidirectionalCombineModeNone`.
    pub const NONE: usize = 0;
    /// `MPSRNNBidirectionalCombineModeAdd`.
    pub const ADD: usize = 1;
    /// `MPSRNNBidirectionalCombineModeConcatenate`.
    pub const CONCATENATE: usize = 2;
}
20
/// Values of `MPSCNNConvolutionFlags`, passed as the `flags` argument of
/// `CnnConvolution::new`.
pub mod cnn_convolution_flags {
    /// `MPSCNNConvolutionFlagsNone`.
    pub const NONE: usize = 0;
}
25
/// Values of `MPSCNNConvolutionWeightsLayout`.
///
/// Unlike the other constant modules in this file, these are typed as `u32`.
pub mod cnn_convolution_weights_layout {
    /// `MPSCNNConvolutionWeightsLayoutOHWI`: output channels, kernel height,
    /// kernel width, input channels.
    pub const OHWI: u32 = 0;
}
30
/// Values of `MPSNNConvolutionAccumulatorPrecisionOption`, used with
/// `CnnConvolution::set_accumulator_precision_option`.
pub mod cnn_accumulator_precision_option {
    /// `MPSNNConvolutionAccumulatorPrecisionOptionHalf` (the empty option set).
    pub const HALF: usize = 0;
    /// `MPSNNConvolutionAccumulatorPrecisionOptionFloat` (bit 0).
    pub const FLOAT: usize = 1 << 0;
}
36
/// Values of `MPSNNRegularizationType`, used by the optimizer descriptors
/// and optimizers in this module.
pub mod nn_regularization_type {
    /// `MPSNNRegularizationTypeNone`.
    pub const NONE: usize = 0;
    /// `MPSNNRegularizationTypeL1`.
    pub const L1: usize = 1;
    /// `MPSNNRegularizationTypeL2`.
    pub const L2: usize = 2;
}
43
/// Declares `pub struct $ty`, an owning wrapper around a single MPS object
/// pointer.
///
/// The generated type:
/// * exposes the raw pointer through `as_ptr` without giving up ownership;
/// * releases a non-null pointer with `ffi::mps_object_release` on drop.
macro_rules! opaque_handle {
    ($ty:ident) => {
        pub struct $ty {
            ptr: *mut c_void,
        }

        // NOTE(review): these impls let handles cross and be shared between
        // threads, while many wrappers mutate the object through `&self`
        // setters. Confirm the underlying objects tolerate concurrent use.
        unsafe impl Send for $ty {}
        unsafe impl Sync for $ty {}

        impl $ty {
            /// Returns the wrapped object pointer; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $ty {
            fn drop(&mut self) {
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: a non-null `ptr` is the reference this handle owns,
                // and it has not been released yet.
                unsafe { ffi::mps_object_release(self.ptr) };
                self.ptr = ptr::null_mut();
            }
        }
    };
}
70
/// Adds a `result_image` accessor to a filter-node handle type created with
/// `opaque_handle!`.
macro_rules! impl_filter_result_image {
    ($node:ident) => {
        impl $node {
            /// Returns the node's result image, or `None` if
            /// `ffi::mps_nn_filter_node_result_image` yields null.
            ///
            /// NOTE(review): the returned `NNImageNode` releases its pointer on
            /// drop. This assumes the FFI call returns a retained (+1)
            /// reference; confirm against the shim.
            #[must_use]
            pub fn result_image(&self) -> Option<NNImageNode> {
                // SAFETY: `self.ptr` is the live node owned by `self`.
                let image = unsafe { ffi::mps_nn_filter_node_result_image(self.ptr) };
                ptr::NonNull::new(image).map(|non_null| NNImageNode {
                    ptr: non_null.as_ptr(),
                })
            }
        }
    };
}
86
87fn retained_handle(ptr: *mut c_void) -> Option<*mut c_void> {
88    let retained = unsafe { ffi::mps_object_retain(ptr) };
89    if retained.is_null() {
90        None
91    } else {
92        Some(retained)
93    }
94}
95
/// Generates the RNN-descriptor property accessors for a handle type.
///
/// Each accessor forwards to the matching `ffi::mps_rnn_descriptor_*`
/// function on `self.ptr`.
macro_rules! impl_rnn_descriptor_common {
    ($descriptor:ident) => {
        impl $descriptor {
            /// Number of input feature channels.
            #[must_use]
            pub fn input_feature_channels(&self) -> usize {
                // SAFETY (all accessors below): `self.ptr` is the live
                // descriptor owned by `self`.
                unsafe { ffi::mps_rnn_descriptor_input_feature_channels(self.ptr) }
            }

            /// Sets the number of input feature channels.
            pub fn set_input_feature_channels(&self, value: usize) {
                unsafe {
                    ffi::mps_rnn_descriptor_set_input_feature_channels(self.ptr, value);
                }
            }

            /// Number of output feature channels.
            #[must_use]
            pub fn output_feature_channels(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_output_feature_channels(self.ptr) }
            }

            /// Sets the number of output feature channels.
            pub fn set_output_feature_channels(&self, value: usize) {
                unsafe {
                    ffi::mps_rnn_descriptor_set_output_feature_channels(self.ptr, value);
                }
            }

            /// Whether the layer-input unit-transform mode flag is set.
            #[must_use]
            pub fn use_layer_input_unit_transform_mode(&self) -> bool {
                unsafe { ffi::mps_rnn_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
            }

            /// Sets the layer-input unit-transform mode flag.
            pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
                unsafe {
                    ffi::mps_rnn_descriptor_set_use_layer_input_unit_transform_mode(
                        self.ptr, value,
                    );
                }
            }

            /// Whether the float32-weights flag is set.
            #[must_use]
            pub fn use_float32_weights(&self) -> bool {
                unsafe { ffi::mps_rnn_descriptor_use_float32_weights(self.ptr) }
            }

            /// Sets the float32-weights flag.
            pub fn set_use_float32_weights(&self, value: bool) {
                unsafe {
                    ffi::mps_rnn_descriptor_set_use_float32_weights(self.ptr, value);
                }
            }

            /// Layer sequence direction (see `rnn_sequence_direction`).
            #[must_use]
            pub fn layer_sequence_direction(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_layer_sequence_direction(self.ptr) }
            }

            /// Sets the layer sequence direction (see `rnn_sequence_direction`).
            pub fn set_layer_sequence_direction(&self, value: usize) {
                unsafe {
                    ffi::mps_rnn_descriptor_set_layer_sequence_direction(self.ptr, value);
                }
            }
        }
    };
}
148
/// Generates the accessors shared by every optimizer handle type in this
/// module.
///
/// Each accessor forwards to the matching `ffi::mps_nn_optimizer_*` function
/// on `self.ptr`.
macro_rules! impl_optimizer_common {
    ($optimizer:ident) => {
        impl $optimizer {
            /// Current learning rate.
            #[must_use]
            pub fn learning_rate(&self) -> f32 {
                // SAFETY (all accessors below): `self.ptr` is the live
                // optimizer owned by `self`.
                unsafe { ffi::mps_nn_optimizer_learning_rate(self.ptr) }
            }

            /// Replaces the learning rate.
            pub fn set_learning_rate(&self, value: f32) {
                unsafe {
                    ffi::mps_nn_optimizer_set_learning_rate(self.ptr, value);
                }
            }

            /// Gradient rescale factor (no setter is exposed here).
            #[must_use]
            pub fn gradient_rescale(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_rescale(self.ptr) }
            }

            /// Whether gradient clipping is enabled.
            #[must_use]
            pub fn apply_gradient_clipping(&self) -> bool {
                unsafe { ffi::mps_nn_optimizer_apply_gradient_clipping(self.ptr) }
            }

            /// Enables or disables gradient clipping.
            pub fn set_apply_gradient_clipping(&self, value: bool) {
                unsafe {
                    ffi::mps_nn_optimizer_set_apply_gradient_clipping(self.ptr, value);
                }
            }

            /// Upper gradient clipping bound.
            #[must_use]
            pub fn gradient_clip_max(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_clip_max(self.ptr) }
            }

            /// Lower gradient clipping bound.
            #[must_use]
            pub fn gradient_clip_min(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_clip_min(self.ptr) }
            }

            /// Regularization scale.
            #[must_use]
            pub fn regularization_scale(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_regularization_scale(self.ptr) }
            }

            /// Regularization type (see `nn_regularization_type`).
            #[must_use]
            pub fn regularization_type(&self) -> usize {
                unsafe { ffi::mps_nn_optimizer_regularization_type(self.ptr) }
            }
        }
    };
}
197
198opaque_handle!(NNImageNode);
199impl NNImageNode {
200    #[must_use]
201    pub fn new() -> Option<Self> {
202        let ptr = unsafe { ffi::mps_nn_image_node_new() };
203        if ptr.is_null() {
204            None
205        } else {
206            Some(Self { ptr })
207        }
208    }
209
210    #[must_use]
211    pub fn exported() -> Option<Self> {
212        let ptr = unsafe { ffi::mps_nn_image_node_exported() };
213        if ptr.is_null() {
214            None
215        } else {
216            Some(Self { ptr })
217        }
218    }
219
220    #[must_use]
221    pub fn format(&self) -> usize {
222        unsafe { ffi::mps_nn_image_node_format(self.ptr) }
223    }
224
225    pub fn set_format(&self, format: usize) {
226        unsafe { ffi::mps_nn_image_node_set_format(self.ptr, format) };
227    }
228
229    #[must_use]
230    pub fn export_from_graph(&self) -> bool {
231        unsafe { ffi::mps_nn_image_node_export_from_graph(self.ptr) }
232    }
233
234    pub fn set_export_from_graph(&self, export: bool) {
235        unsafe { ffi::mps_nn_image_node_set_export_from_graph(self.ptr, export) };
236    }
237
238    #[must_use]
239    pub fn synchronize_resource(&self) -> bool {
240        unsafe { ffi::mps_nn_image_node_synchronize_resource(self.ptr) }
241    }
242
243    pub fn set_synchronize_resource(&self, synchronize: bool) {
244        unsafe { ffi::mps_nn_image_node_set_synchronize_resource(self.ptr, synchronize) };
245    }
246
247    pub fn use_default_allocator(&self) {
248        unsafe { ffi::mps_nn_image_node_use_default_allocator(self.ptr) };
249    }
250}
251
252opaque_handle!(CnnNeuronReluNode);
253impl CnnNeuronReluNode {
254    #[must_use]
255    pub fn new(source: &NNImageNode, a: f32) -> Option<Self> {
256        let ptr = unsafe { ffi::mps_cnn_neuron_relu_node_new(source.as_ptr(), a) };
257        if ptr.is_null() {
258            None
259        } else {
260            Some(Self { ptr })
261        }
262    }
263}
264impl_filter_result_image!(CnnNeuronReluNode);
265
266opaque_handle!(CnnPoolingMaxNode);
267impl CnnPoolingMaxNode {
268    #[must_use]
269    pub fn new(source: &NNImageNode, filter_size: usize, stride: usize) -> Option<Self> {
270        let ptr =
271            unsafe { ffi::mps_cnn_pooling_max_node_new(source.as_ptr(), filter_size, stride) };
272        if ptr.is_null() {
273            None
274        } else {
275            Some(Self { ptr })
276        }
277    }
278}
279impl_filter_result_image!(CnnPoolingMaxNode);
280
281opaque_handle!(CnnSoftMaxNode);
282impl CnnSoftMaxNode {
283    #[must_use]
284    pub fn new(source: &NNImageNode) -> Option<Self> {
285        let ptr = unsafe { ffi::mps_cnn_softmax_node_new(source.as_ptr()) };
286        if ptr.is_null() {
287            None
288        } else {
289            Some(Self { ptr })
290        }
291    }
292}
293impl_filter_result_image!(CnnSoftMaxNode);
294
295opaque_handle!(CnnUpsamplingNearestNode);
296impl CnnUpsamplingNearestNode {
297    #[must_use]
298    pub fn new(source: &NNImageNode, scale_x: usize, scale_y: usize) -> Option<Self> {
299        let ptr =
300            unsafe { ffi::mps_cnn_upsampling_nearest_node_new(source.as_ptr(), scale_x, scale_y) };
301        if ptr.is_null() {
302            None
303        } else {
304            Some(Self { ptr })
305        }
306    }
307}
308impl_filter_result_image!(CnnUpsamplingNearestNode);
309
310opaque_handle!(NNGraph);
311impl NNGraph {
312    #[must_use]
313    pub fn new(
314        device: &MetalDevice,
315        result_image: &NNImageNode,
316        result_image_is_needed: bool,
317    ) -> Option<Self> {
318        let ptr = unsafe {
319            ffi::mps_nn_graph_new(
320                device.as_ptr(),
321                result_image.as_ptr(),
322                result_image_is_needed,
323            )
324        };
325        if ptr.is_null() {
326            None
327        } else {
328            Some(Self { ptr })
329        }
330    }
331
332    #[must_use]
333    pub fn source_image_count(&self) -> usize {
334        unsafe { ffi::mps_nn_graph_source_image_count(self.ptr) }
335    }
336
337    #[must_use]
338    pub fn format(&self) -> usize {
339        unsafe { ffi::mps_nn_graph_format(self.ptr) }
340    }
341
342    pub fn set_format(&self, format: usize) {
343        unsafe { ffi::mps_nn_graph_set_format(self.ptr, format) };
344    }
345
346    pub fn set_output_state_is_temporary(&self, temporary: bool) {
347        unsafe { ffi::mps_nn_graph_set_output_state_is_temporary(self.ptr, temporary) };
348    }
349
350    pub fn use_default_destination_image_allocator(&self) {
351        unsafe { ffi::mps_nn_graph_use_default_destination_image_allocator(self.ptr) };
352    }
353
354    pub fn reload_from_data_sources(&self) {
355        unsafe { ffi::mps_nn_graph_reload_from_data_sources(self.ptr) };
356    }
357
358    #[must_use]
359    pub fn encode(
360        &self,
361        command_buffer: &CommandBuffer,
362        source_images: &[&Image],
363    ) -> Option<Image> {
364        let handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
365        let source_handles = if handles.is_empty() {
366            ptr::null()
367        } else {
368            handles.as_ptr()
369        };
370        let ptr = unsafe {
371            ffi::mps_nn_graph_encode(
372                self.ptr,
373                command_buffer.as_ptr(),
374                source_images.len(),
375                source_handles,
376            )
377        };
378        if ptr.is_null() {
379            None
380        } else {
381            Some(unsafe { Image::from_raw(ptr) })
382        }
383    }
384}
385
386opaque_handle!(CnnConvolutionDescriptor);
387impl CnnConvolutionDescriptor {
388    #[must_use]
389    pub fn new(
390        kernel_width: usize,
391        kernel_height: usize,
392        input_feature_channels: usize,
393        output_feature_channels: usize,
394    ) -> Option<Self> {
395        let ptr = unsafe {
396            ffi::mps_cnn_convolution_descriptor_new(
397                kernel_width,
398                kernel_height,
399                input_feature_channels,
400                output_feature_channels,
401            )
402        };
403        if ptr.is_null() {
404            None
405        } else {
406            Some(Self { ptr })
407        }
408    }
409
410    #[must_use]
411    pub fn kernel_width(&self) -> usize {
412        unsafe { ffi::mps_cnn_convolution_descriptor_kernel_width(self.ptr) }
413    }
414
415    #[must_use]
416    pub fn kernel_height(&self) -> usize {
417        unsafe { ffi::mps_cnn_convolution_descriptor_kernel_height(self.ptr) }
418    }
419
420    #[must_use]
421    pub fn input_feature_channels(&self) -> usize {
422        unsafe { ffi::mps_cnn_convolution_descriptor_input_feature_channels(self.ptr) }
423    }
424
425    #[must_use]
426    pub fn output_feature_channels(&self) -> usize {
427        unsafe { ffi::mps_cnn_convolution_descriptor_output_feature_channels(self.ptr) }
428    }
429
430    #[must_use]
431    pub fn stride_in_pixels_x(&self) -> usize {
432        unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_x(self.ptr) }
433    }
434
435    pub fn set_stride_in_pixels_x(&self, value: usize) {
436        unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_x(self.ptr, value) };
437    }
438
439    #[must_use]
440    pub fn stride_in_pixels_y(&self) -> usize {
441        unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_y(self.ptr) }
442    }
443
444    pub fn set_stride_in_pixels_y(&self, value: usize) {
445        unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_y(self.ptr, value) };
446    }
447
448    #[must_use]
449    pub fn groups(&self) -> usize {
450        unsafe { ffi::mps_cnn_convolution_descriptor_groups(self.ptr) }
451    }
452
453    pub fn set_groups(&self, value: usize) {
454        unsafe { ffi::mps_cnn_convolution_descriptor_set_groups(self.ptr, value) };
455    }
456
457    #[must_use]
458    pub fn dilation_rate_x(&self) -> usize {
459        unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_x(self.ptr) }
460    }
461
462    pub fn set_dilation_rate_x(&self, value: usize) {
463        unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_x(self.ptr, value) };
464    }
465
466    #[must_use]
467    pub fn dilation_rate_y(&self) -> usize {
468        unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_y(self.ptr) }
469    }
470
471    pub fn set_dilation_rate_y(&self, value: usize) {
472        unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_y(self.ptr, value) };
473    }
474}
475
476opaque_handle!(RnnSingleGateDescriptor);
477impl RnnSingleGateDescriptor {
478    #[must_use]
479    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
480        let ptr = unsafe {
481            ffi::mps_rnn_single_gate_descriptor_new(input_feature_channels, output_feature_channels)
482        };
483        if ptr.is_null() {
484            None
485        } else {
486            Some(Self { ptr })
487        }
488    }
489
490    #[must_use]
491    pub fn input_feature_channels(&self) -> usize {
492        unsafe { ffi::mps_rnn_single_gate_descriptor_input_feature_channels(self.ptr) }
493    }
494
495    pub fn set_input_feature_channels(&self, value: usize) {
496        unsafe { ffi::mps_rnn_single_gate_descriptor_set_input_feature_channels(self.ptr, value) };
497    }
498
499    #[must_use]
500    pub fn output_feature_channels(&self) -> usize {
501        unsafe { ffi::mps_rnn_single_gate_descriptor_output_feature_channels(self.ptr) }
502    }
503
504    pub fn set_output_feature_channels(&self, value: usize) {
505        unsafe { ffi::mps_rnn_single_gate_descriptor_set_output_feature_channels(self.ptr, value) };
506    }
507
508    #[must_use]
509    pub fn use_layer_input_unit_transform_mode(&self) -> bool {
510        unsafe { ffi::mps_rnn_single_gate_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
511    }
512
513    pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
514        unsafe {
515            ffi::mps_rnn_single_gate_descriptor_set_use_layer_input_unit_transform_mode(
516                self.ptr, value,
517            );
518        };
519    }
520
521    #[must_use]
522    pub fn use_float32_weights(&self) -> bool {
523        unsafe { ffi::mps_rnn_single_gate_descriptor_use_float32_weights(self.ptr) }
524    }
525
526    pub fn set_use_float32_weights(&self, value: bool) {
527        unsafe { ffi::mps_rnn_single_gate_descriptor_set_use_float32_weights(self.ptr, value) };
528    }
529
530    #[must_use]
531    pub fn layer_sequence_direction(&self) -> usize {
532        unsafe { ffi::mps_rnn_single_gate_descriptor_layer_sequence_direction(self.ptr) }
533    }
534
535    pub fn set_layer_sequence_direction(&self, value: usize) {
536        unsafe {
537            ffi::mps_rnn_single_gate_descriptor_set_layer_sequence_direction(self.ptr, value);
538        };
539    }
540
541    #[must_use]
542    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
543        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
544    }
545}
546
547opaque_handle!(CnnConvolution);
548impl CnnConvolution {
549    #[must_use]
550    pub fn new(
551        device: &MetalDevice,
552        descriptor: &CnnConvolutionDescriptor,
553        kernel_weights: &[f32],
554        bias_terms: Option<&[f32]>,
555        flags: usize,
556    ) -> Option<Self> {
557        if kernel_weights.is_empty() {
558            return None;
559        }
560        let bias_terms_ptr = bias_terms.map_or(ptr::null(), <[f32]>::as_ptr);
561        let ptr = unsafe {
562            ffi::mps_cnn_convolution_new(
563                device.as_ptr(),
564                descriptor.as_ptr(),
565                kernel_weights.as_ptr(),
566                bias_terms_ptr,
567                flags,
568            )
569        };
570        if ptr.is_null() {
571            None
572        } else {
573            Some(Self { ptr })
574        }
575    }
576
577    #[must_use]
578    pub fn input_feature_channels(&self) -> usize {
579        unsafe { ffi::mps_cnn_convolution_input_feature_channels(self.ptr) }
580    }
581
582    #[must_use]
583    pub fn output_feature_channels(&self) -> usize {
584        unsafe { ffi::mps_cnn_convolution_output_feature_channels(self.ptr) }
585    }
586
587    #[must_use]
588    pub fn groups(&self) -> usize {
589        unsafe { ffi::mps_cnn_convolution_groups(self.ptr) }
590    }
591
592    #[must_use]
593    pub fn sub_pixel_scale_factor(&self) -> usize {
594        unsafe { ffi::mps_cnn_convolution_sub_pixel_scale_factor(self.ptr) }
595    }
596
597    #[must_use]
598    pub fn channel_multiplier(&self) -> usize {
599        unsafe { ffi::mps_cnn_convolution_channel_multiplier(self.ptr) }
600    }
601
602    #[must_use]
603    pub fn accumulator_precision_option(&self) -> usize {
604        unsafe { ffi::mps_cnn_convolution_accumulator_precision_option(self.ptr) }
605    }
606
607    pub fn set_accumulator_precision_option(&self, value: usize) {
608        unsafe { ffi::mps_cnn_convolution_set_accumulator_precision_option(self.ptr, value) };
609    }
610
611    pub fn encode_image(&self, command_buffer: &CommandBuffer, source: &Image, destination: &Image) {
612        unsafe {
613            ffi::mps_cnn_convolution_encode_image(
614                self.ptr,
615                command_buffer.as_ptr(),
616                source.as_ptr(),
617                destination.as_ptr(),
618            );
619        };
620    }
621}
622
623opaque_handle!(CnnConvolutionWeightsAndBiasesState);
624impl CnnConvolutionWeightsAndBiasesState {
625    #[must_use]
626    pub fn new_with_buffers(weights: &MetalBuffer, biases: Option<&MetalBuffer>) -> Option<Self> {
627        let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
628        let ptr = unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_new(weights.as_ptr(), biases_ptr) };
629        if ptr.is_null() {
630            None
631        } else {
632            Some(Self { ptr })
633        }
634    }
635
636    #[must_use]
637    pub fn new_with_offsets(
638        weights: &MetalBuffer,
639        weights_offset: usize,
640        biases: Option<&MetalBuffer>,
641        biases_offset: usize,
642        descriptor: &CnnConvolutionDescriptor,
643    ) -> Option<Self> {
644        let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
645        let ptr = unsafe {
646            ffi::mps_cnn_convolution_weights_and_biases_state_new_with_offsets(
647                weights.as_ptr(),
648                weights_offset,
649                biases_ptr,
650                biases_offset,
651                descriptor.as_ptr(),
652            )
653        };
654        if ptr.is_null() {
655            None
656        } else {
657            Some(Self { ptr })
658        }
659    }
660
661    #[must_use]
662    pub fn new_with_device(
663        device: &MetalDevice,
664        descriptor: &CnnConvolutionDescriptor,
665    ) -> Option<Self> {
666        let ptr = unsafe {
667            ffi::mps_cnn_convolution_weights_and_biases_state_new_with_device(
668                device.as_ptr(),
669                descriptor.as_ptr(),
670            )
671        };
672        if ptr.is_null() {
673            None
674        } else {
675            Some(Self { ptr })
676        }
677    }
678
679    #[must_use]
680    pub fn weights_offset(&self) -> usize {
681        unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_weights_offset(self.ptr) }
682    }
683
684    #[must_use]
685    pub fn biases_offset(&self) -> usize {
686        unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_biases_offset(self.ptr) }
687    }
688}
689
690opaque_handle!(NNOptimizerDescriptor);
691impl NNOptimizerDescriptor {
692    #[must_use]
693    pub fn new(
694        learning_rate: f32,
695        gradient_rescale: f32,
696        regularization_type: usize,
697        regularization_scale: f32,
698    ) -> Option<Self> {
699        let ptr = unsafe {
700            ffi::mps_nn_optimizer_descriptor_new(
701                learning_rate,
702                gradient_rescale,
703                regularization_type,
704                regularization_scale,
705            )
706        };
707        if ptr.is_null() {
708            None
709        } else {
710            Some(Self { ptr })
711        }
712    }
713
714    #[must_use]
715    pub fn with_gradient_clipping(
716        learning_rate: f32,
717        gradient_rescale: f32,
718        apply_gradient_clipping: bool,
719        gradient_clip_max: f32,
720        gradient_clip_min: f32,
721        regularization_type: usize,
722        regularization_scale: f32,
723    ) -> Option<Self> {
724        let ptr = unsafe {
725            ffi::mps_nn_optimizer_descriptor_new_with_gradient_clipping(
726                learning_rate,
727                gradient_rescale,
728                apply_gradient_clipping,
729                gradient_clip_max,
730                gradient_clip_min,
731                regularization_type,
732                regularization_scale,
733            )
734        };
735        if ptr.is_null() {
736            None
737        } else {
738            Some(Self { ptr })
739        }
740    }
741
742    #[must_use]
743    pub fn learning_rate(&self) -> f32 {
744        unsafe { ffi::mps_nn_optimizer_descriptor_learning_rate(self.ptr) }
745    }
746
747    pub fn set_learning_rate(&self, value: f32) {
748        unsafe { ffi::mps_nn_optimizer_descriptor_set_learning_rate(self.ptr, value) };
749    }
750
751    #[must_use]
752    pub fn gradient_rescale(&self) -> f32 {
753        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_rescale(self.ptr) }
754    }
755
756    pub fn set_gradient_rescale(&self, value: f32) {
757        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_rescale(self.ptr, value) };
758    }
759
760    #[must_use]
761    pub fn apply_gradient_clipping(&self) -> bool {
762        unsafe { ffi::mps_nn_optimizer_descriptor_apply_gradient_clipping(self.ptr) }
763    }
764
765    pub fn set_apply_gradient_clipping(&self, value: bool) {
766        unsafe { ffi::mps_nn_optimizer_descriptor_set_apply_gradient_clipping(self.ptr, value) };
767    }
768
769    #[must_use]
770    pub fn gradient_clip_max(&self) -> f32 {
771        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_max(self.ptr) }
772    }
773
774    pub fn set_gradient_clip_max(&self, value: f32) {
775        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_max(self.ptr, value) };
776    }
777
778    #[must_use]
779    pub fn gradient_clip_min(&self) -> f32 {
780        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_min(self.ptr) }
781    }
782
783    pub fn set_gradient_clip_min(&self, value: f32) {
784        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_min(self.ptr, value) };
785    }
786
787    #[must_use]
788    pub fn regularization_scale(&self) -> f32 {
789        unsafe { ffi::mps_nn_optimizer_descriptor_regularization_scale(self.ptr) }
790    }
791
792    pub fn set_regularization_scale(&self, value: f32) {
793        unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_scale(self.ptr, value) };
794    }
795
796    #[must_use]
797    pub fn regularization_type(&self) -> usize {
798        unsafe { ffi::mps_nn_optimizer_descriptor_regularization_type(self.ptr) }
799    }
800
801    pub fn set_regularization_type(&self, value: usize) {
802        unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_type(self.ptr, value) };
803    }
804}
805
806opaque_handle!(NNOptimizer);
807impl_optimizer_common!(NNOptimizer);
808
809opaque_handle!(NNOptimizerStochasticGradientDescent);
810impl_optimizer_common!(NNOptimizerStochasticGradientDescent);
811impl NNOptimizerStochasticGradientDescent {
812    #[must_use]
813    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
814        let ptr = unsafe { ffi::mps_nn_optimizer_sgd_new(device.as_ptr(), learning_rate) };
815        if ptr.is_null() {
816            None
817        } else {
818            Some(Self { ptr })
819        }
820    }
821
822    #[must_use]
823    pub fn new_with_options(
824        device: &MetalDevice,
825        momentum_scale: f32,
826        use_nesterov_momentum: bool,
827        optimizer_descriptor: &NNOptimizerDescriptor,
828    ) -> Option<Self> {
829        let ptr = unsafe {
830            ffi::mps_nn_optimizer_sgd_new_with_options(
831                device.as_ptr(),
832                momentum_scale,
833                use_nesterov_momentum,
834                optimizer_descriptor.as_ptr(),
835            )
836        };
837        if ptr.is_null() {
838            None
839        } else {
840            Some(Self { ptr })
841        }
842    }
843
844    #[must_use]
845    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
846        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
847    }
848
849    #[must_use]
850    pub fn momentum_scale(&self) -> f32 {
851        unsafe { ffi::mps_nn_optimizer_sgd_momentum_scale(self.ptr) }
852    }
853
854    #[must_use]
855    pub fn use_nesterov_momentum(&self) -> bool {
856        unsafe { ffi::mps_nn_optimizer_sgd_use_nesterov_momentum(self.ptr) }
857    }
858
859    pub fn encode_vector(
860        &self,
861        command_buffer: &CommandBuffer,
862        input_gradient_vector: &Vector,
863        input_values_vector: &Vector,
864        input_momentum_vector: Option<&Vector>,
865        result_values_vector: &Vector,
866    ) {
867        let input_momentum_ptr = input_momentum_vector.map_or(ptr::null_mut(), Vector::as_ptr);
868        unsafe {
869            ffi::mps_nn_optimizer_sgd_encode_vector(
870                self.ptr,
871                command_buffer.as_ptr(),
872                input_gradient_vector.as_ptr(),
873                input_values_vector.as_ptr(),
874                input_momentum_ptr,
875                result_values_vector.as_ptr(),
876            );
877        };
878    }
879
880    pub fn encode_matrix(
881        &self,
882        command_buffer: &CommandBuffer,
883        input_gradient_matrix: &Matrix,
884        input_values_matrix: &Matrix,
885        input_momentum_matrix: Option<&Matrix>,
886        result_values_matrix: &Matrix,
887    ) {
888        let input_momentum_ptr = input_momentum_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
889        unsafe {
890            ffi::mps_nn_optimizer_sgd_encode_matrix(
891                self.ptr,
892                command_buffer.as_ptr(),
893                input_gradient_matrix.as_ptr(),
894                input_values_matrix.as_ptr(),
895                input_momentum_ptr,
896                result_values_matrix.as_ptr(),
897            );
898        };
899    }
900}
901
902opaque_handle!(NNOptimizerRmsProp);
903impl_optimizer_common!(NNOptimizerRmsProp);
904impl NNOptimizerRmsProp {
905    #[must_use]
906    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
907        let ptr = unsafe { ffi::mps_nn_optimizer_rmsprop_new(device.as_ptr(), learning_rate) };
908        if ptr.is_null() {
909            None
910        } else {
911            Some(Self { ptr })
912        }
913    }
914
915    #[must_use]
916    pub fn new_with_options(
917        device: &MetalDevice,
918        decay: f64,
919        epsilon: f32,
920        optimizer_descriptor: &NNOptimizerDescriptor,
921    ) -> Option<Self> {
922        let ptr = unsafe {
923            ffi::mps_nn_optimizer_rmsprop_new_with_options(
924                device.as_ptr(),
925                decay,
926                epsilon,
927                optimizer_descriptor.as_ptr(),
928            )
929        };
930        if ptr.is_null() {
931            None
932        } else {
933            Some(Self { ptr })
934        }
935    }
936
937    #[must_use]
938    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
939        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
940    }
941
942    #[must_use]
943    pub fn decay(&self) -> f64 {
944        unsafe { ffi::mps_nn_optimizer_rmsprop_decay(self.ptr) }
945    }
946
947    #[must_use]
948    pub fn epsilon(&self) -> f32 {
949        unsafe { ffi::mps_nn_optimizer_rmsprop_epsilon(self.ptr) }
950    }
951
952    pub fn encode_vector(
953        &self,
954        command_buffer: &CommandBuffer,
955        input_gradient_vector: &Vector,
956        input_values_vector: &Vector,
957        input_sum_of_squares_vector: &Vector,
958        result_values_vector: &Vector,
959    ) {
960        unsafe {
961            ffi::mps_nn_optimizer_rmsprop_encode_vector(
962                self.ptr,
963                command_buffer.as_ptr(),
964                input_gradient_vector.as_ptr(),
965                input_values_vector.as_ptr(),
966                input_sum_of_squares_vector.as_ptr(),
967                result_values_vector.as_ptr(),
968            );
969        };
970    }
971
972    pub fn encode_matrix(
973        &self,
974        command_buffer: &CommandBuffer,
975        input_gradient_matrix: &Matrix,
976        input_values_matrix: &Matrix,
977        input_sum_of_squares_matrix: &Matrix,
978        result_values_matrix: &Matrix,
979    ) {
980        unsafe {
981            ffi::mps_nn_optimizer_rmsprop_encode_matrix(
982                self.ptr,
983                command_buffer.as_ptr(),
984                input_gradient_matrix.as_ptr(),
985                input_values_matrix.as_ptr(),
986                input_sum_of_squares_matrix.as_ptr(),
987                result_values_matrix.as_ptr(),
988            );
989        };
990    }
991}
992
993opaque_handle!(NNOptimizerAdam);
994impl_optimizer_common!(NNOptimizerAdam);
995impl NNOptimizerAdam {
996    #[must_use]
997    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
998        let ptr = unsafe { ffi::mps_nn_optimizer_adam_new(device.as_ptr(), learning_rate) };
999        if ptr.is_null() {
1000            None
1001        } else {
1002            Some(Self { ptr })
1003        }
1004    }
1005
1006    #[must_use]
1007    pub fn new_with_options(
1008        device: &MetalDevice,
1009        beta1: f64,
1010        beta2: f64,
1011        epsilon: f32,
1012        time_step: usize,
1013        optimizer_descriptor: &NNOptimizerDescriptor,
1014    ) -> Option<Self> {
1015        let ptr = unsafe {
1016            ffi::mps_nn_optimizer_adam_new_with_options(
1017                device.as_ptr(),
1018                beta1,
1019                beta2,
1020                epsilon,
1021                time_step,
1022                optimizer_descriptor.as_ptr(),
1023            )
1024        };
1025        if ptr.is_null() {
1026            None
1027        } else {
1028            Some(Self { ptr })
1029        }
1030    }
1031
1032    #[must_use]
1033    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
1034        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
1035    }
1036
1037    #[must_use]
1038    pub fn beta1(&self) -> f64 {
1039        unsafe { ffi::mps_nn_optimizer_adam_beta1(self.ptr) }
1040    }
1041
1042    #[must_use]
1043    pub fn beta2(&self) -> f64 {
1044        unsafe { ffi::mps_nn_optimizer_adam_beta2(self.ptr) }
1045    }
1046
1047    #[must_use]
1048    pub fn epsilon(&self) -> f32 {
1049        unsafe { ffi::mps_nn_optimizer_adam_epsilon(self.ptr) }
1050    }
1051
1052    #[must_use]
1053    pub fn time_step(&self) -> usize {
1054        unsafe { ffi::mps_nn_optimizer_adam_time_step(self.ptr) }
1055    }
1056
1057    pub fn set_time_step(&self, value: usize) {
1058        unsafe { ffi::mps_nn_optimizer_adam_set_time_step(self.ptr, value) };
1059    }
1060
1061    pub fn encode_vector(
1062        &self,
1063        command_buffer: &CommandBuffer,
1064        input_gradient_vector: &Vector,
1065        input_values_vector: &Vector,
1066        input_momentum_vector: &Vector,
1067        input_velocity_vector: &Vector,
1068        result_values_vector: &Vector,
1069    ) {
1070        unsafe {
1071            ffi::mps_nn_optimizer_adam_encode_vector(
1072                self.ptr,
1073                command_buffer.as_ptr(),
1074                input_gradient_vector.as_ptr(),
1075                input_values_vector.as_ptr(),
1076                input_momentum_vector.as_ptr(),
1077                input_velocity_vector.as_ptr(),
1078                result_values_vector.as_ptr(),
1079            );
1080        };
1081    }
1082
1083    pub fn encode_matrix(
1084        &self,
1085        command_buffer: &CommandBuffer,
1086        input_gradient_matrix: &Matrix,
1087        input_values_matrix: &Matrix,
1088        input_momentum_matrix: &Matrix,
1089        input_velocity_matrix: &Matrix,
1090        result_values_matrix: &Matrix,
1091    ) {
1092        unsafe {
1093            ffi::mps_nn_optimizer_adam_encode_matrix(
1094                self.ptr,
1095                command_buffer.as_ptr(),
1096                input_gradient_matrix.as_ptr(),
1097                input_values_matrix.as_ptr(),
1098                input_momentum_matrix.as_ptr(),
1099                input_velocity_matrix.as_ptr(),
1100                result_values_matrix.as_ptr(),
1101            );
1102        };
1103    }
1104
1105    #[allow(clippy::too_many_arguments)]
1106    pub fn encode_amsgrad_vector(
1107        &self,
1108        command_buffer: &CommandBuffer,
1109        input_gradient_vector: &Vector,
1110        input_values_vector: &Vector,
1111        input_momentum_vector: &Vector,
1112        input_velocity_vector: &Vector,
1113        maximum_velocity_vector: Option<&Vector>,
1114        result_values_vector: &Vector,
1115    ) {
1116        let maximum_velocity_ptr = maximum_velocity_vector.map_or(ptr::null_mut(), Vector::as_ptr);
1117        unsafe {
1118            ffi::mps_nn_optimizer_adam_encode_amsgrad_vector(
1119                self.ptr,
1120                command_buffer.as_ptr(),
1121                input_gradient_vector.as_ptr(),
1122                input_values_vector.as_ptr(),
1123                input_momentum_vector.as_ptr(),
1124                input_velocity_vector.as_ptr(),
1125                maximum_velocity_ptr,
1126                result_values_vector.as_ptr(),
1127            );
1128        };
1129    }
1130
1131    #[allow(clippy::too_many_arguments)]
1132    pub fn encode_amsgrad_matrix(
1133        &self,
1134        command_buffer: &CommandBuffer,
1135        input_gradient_matrix: &Matrix,
1136        input_values_matrix: &Matrix,
1137        input_momentum_matrix: &Matrix,
1138        input_velocity_matrix: &Matrix,
1139        maximum_velocity_matrix: Option<&Matrix>,
1140        result_values_matrix: &Matrix,
1141    ) {
1142        let maximum_velocity_ptr = maximum_velocity_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
1143        unsafe {
1144            ffi::mps_nn_optimizer_adam_encode_amsgrad_matrix(
1145                self.ptr,
1146                command_buffer.as_ptr(),
1147                input_gradient_matrix.as_ptr(),
1148                input_values_matrix.as_ptr(),
1149                input_momentum_matrix.as_ptr(),
1150                input_velocity_matrix.as_ptr(),
1151                maximum_velocity_ptr,
1152                result_values_matrix.as_ptr(),
1153            );
1154        };
1155    }
1156}
1157
// Generic RNN descriptor handle. `GruDescriptor::as_descriptor` and
// `LstmDescriptor::as_descriptor` produce values of this type, and
// `RnnImageInferenceLayer::new` / `new_stack` consume it.
// `opaque_handle!` defines the owning struct, which releases its pointer on drop.
opaque_handle!(RnnDescriptor);
impl_rnn_descriptor_common!(RnnDescriptor);
1160
1161opaque_handle!(GruDescriptor);
1162impl_rnn_descriptor_common!(GruDescriptor);
1163impl GruDescriptor {
1164    #[must_use]
1165    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1166        let ptr = unsafe { ffi::mps_gru_descriptor_new(input_feature_channels, output_feature_channels) };
1167        if ptr.is_null() {
1168            None
1169        } else {
1170            Some(Self { ptr })
1171        }
1172    }
1173
1174    #[must_use]
1175    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1176        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1177    }
1178
1179    #[must_use]
1180    pub fn gate_pnorm_value(&self) -> f32 {
1181        unsafe { ffi::mps_gru_descriptor_gate_pnorm_value(self.ptr) }
1182    }
1183
1184    pub fn set_gate_pnorm_value(&self, value: f32) {
1185        unsafe { ffi::mps_gru_descriptor_set_gate_pnorm_value(self.ptr, value) };
1186    }
1187
1188    #[must_use]
1189    pub fn flip_output_gates(&self) -> bool {
1190        unsafe { ffi::mps_gru_descriptor_flip_output_gates(self.ptr) }
1191    }
1192
1193    pub fn set_flip_output_gates(&self, value: bool) {
1194        unsafe { ffi::mps_gru_descriptor_set_flip_output_gates(self.ptr, value) };
1195    }
1196}
1197
1198opaque_handle!(LstmDescriptor);
1199impl_rnn_descriptor_common!(LstmDescriptor);
1200impl LstmDescriptor {
1201    #[must_use]
1202    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1203        let ptr = unsafe { ffi::mps_lstm_descriptor_new(input_feature_channels, output_feature_channels) };
1204        if ptr.is_null() {
1205            None
1206        } else {
1207            Some(Self { ptr })
1208        }
1209    }
1210
1211    #[must_use]
1212    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1213        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1214    }
1215
1216    #[must_use]
1217    pub fn memory_weights_are_diagonal(&self) -> bool {
1218        unsafe { ffi::mps_lstm_descriptor_memory_weights_are_diagonal(self.ptr) }
1219    }
1220
1221    pub fn set_memory_weights_are_diagonal(&self, value: bool) {
1222        unsafe { ffi::mps_lstm_descriptor_set_memory_weights_are_diagonal(self.ptr, value) };
1223    }
1224
1225    #[must_use]
1226    pub fn cell_to_output_neuron_type(&self) -> usize {
1227        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_type(self.ptr) }
1228    }
1229
1230    pub fn set_cell_to_output_neuron_type(&self, value: usize) {
1231        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_type(self.ptr, value) };
1232    }
1233
1234    #[must_use]
1235    pub fn cell_to_output_neuron_param_a(&self) -> f32 {
1236        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_a(self.ptr) }
1237    }
1238
1239    pub fn set_cell_to_output_neuron_param_a(&self, value: f32) {
1240        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_a(self.ptr, value) };
1241    }
1242
1243    #[must_use]
1244    pub fn cell_to_output_neuron_param_b(&self) -> f32 {
1245        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_b(self.ptr) }
1246    }
1247
1248    pub fn set_cell_to_output_neuron_param_b(&self, value: f32) {
1249        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_b(self.ptr, value) };
1250    }
1251
1252    #[must_use]
1253    pub fn cell_to_output_neuron_param_c(&self) -> f32 {
1254        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_c(self.ptr) }
1255    }
1256
1257    pub fn set_cell_to_output_neuron_param_c(&self, value: f32) {
1258        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_c(self.ptr, value) };
1259    }
1260}
1261
1262opaque_handle!(RnnRecurrentImageState);
1263impl RnnRecurrentImageState {
1264    #[must_use]
1265    pub fn recurrent_output_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1266        let ptr = unsafe { ffi::mps_rnn_recurrent_image_state_recurrent_output_image(self.ptr, layer_index) };
1267        if ptr.is_null() {
1268            None
1269        } else {
1270            Some(unsafe { Image::from_raw(ptr) })
1271        }
1272    }
1273
1274    #[must_use]
1275    pub fn memory_cell_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1276        let ptr = unsafe { ffi::mps_rnn_recurrent_image_state_memory_cell_image(self.ptr, layer_index) };
1277        if ptr.is_null() {
1278            None
1279        } else {
1280            Some(unsafe { Image::from_raw(ptr) })
1281        }
1282    }
1283}
1284
1285opaque_handle!(RnnImageInferenceLayer);
1286impl RnnImageInferenceLayer {
1287    #[must_use]
1288    pub fn new(device: &MetalDevice, descriptor: &RnnDescriptor) -> Option<Self> {
1289        let ptr = unsafe { ffi::mps_rnn_image_inference_layer_new(device.as_ptr(), descriptor.as_ptr()) };
1290        if ptr.is_null() {
1291            None
1292        } else {
1293            Some(Self { ptr })
1294        }
1295    }
1296
1297    #[must_use]
1298    pub fn new_stack(device: &MetalDevice, descriptors: &[&RnnDescriptor]) -> Option<Self> {
1299        let handles: Vec<_> = descriptors.iter().map(|descriptor| descriptor.as_ptr()).collect();
1300        let handles_ptr = if handles.is_empty() {
1301            ptr::null()
1302        } else {
1303            handles.as_ptr()
1304        };
1305        let ptr = unsafe {
1306            ffi::mps_rnn_image_inference_layer_new_stack(
1307                device.as_ptr(),
1308                descriptors.len(),
1309                handles_ptr,
1310            )
1311        };
1312        if ptr.is_null() {
1313            None
1314        } else {
1315            Some(Self { ptr })
1316        }
1317    }
1318
1319    #[must_use]
1320    pub fn input_feature_channels(&self) -> usize {
1321        unsafe { ffi::mps_rnn_image_inference_layer_input_feature_channels(self.ptr) }
1322    }
1323
1324    #[must_use]
1325    pub fn output_feature_channels(&self) -> usize {
1326        unsafe { ffi::mps_rnn_image_inference_layer_output_feature_channels(self.ptr) }
1327    }
1328
1329    #[must_use]
1330    pub fn number_of_layers(&self) -> usize {
1331        unsafe { ffi::mps_rnn_image_inference_layer_number_of_layers(self.ptr) }
1332    }
1333
1334    #[must_use]
1335    pub fn recurrent_output_is_temporary(&self) -> bool {
1336        unsafe { ffi::mps_rnn_image_inference_layer_recurrent_output_is_temporary(self.ptr) }
1337    }
1338
1339    pub fn set_recurrent_output_is_temporary(&self, value: bool) {
1340        unsafe { ffi::mps_rnn_image_inference_layer_set_recurrent_output_is_temporary(self.ptr, value) };
1341    }
1342
1343    #[must_use]
1344    pub fn store_all_intermediate_states(&self) -> bool {
1345        unsafe { ffi::mps_rnn_image_inference_layer_store_all_intermediate_states(self.ptr) }
1346    }
1347
1348    pub fn set_store_all_intermediate_states(&self, value: bool) {
1349        unsafe { ffi::mps_rnn_image_inference_layer_set_store_all_intermediate_states(self.ptr, value) };
1350    }
1351
1352    #[must_use]
1353    pub fn bidirectional_combine_mode(&self) -> usize {
1354        unsafe { ffi::mps_rnn_image_inference_layer_bidirectional_combine_mode(self.ptr) }
1355    }
1356
1357    pub fn set_bidirectional_combine_mode(&self, value: usize) {
1358        unsafe { ffi::mps_rnn_image_inference_layer_set_bidirectional_combine_mode(self.ptr, value) };
1359    }
1360
1361    #[must_use]
1362    pub fn encode_sequence(
1363        &self,
1364        command_buffer: &CommandBuffer,
1365        source_images: &[&Image],
1366        destination_images: &[&Image],
1367        recurrent_input_state: Option<&RnnRecurrentImageState>,
1368    ) -> Option<RnnRecurrentImageState> {
1369        if source_images.len() != destination_images.len() {
1370            return None;
1371        }
1372        let source_handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
1373        let destination_handles: Vec<_> = destination_images.iter().map(|image| image.as_ptr()).collect();
1374        let source_ptr = if source_handles.is_empty() {
1375            ptr::null()
1376        } else {
1377            source_handles.as_ptr()
1378        };
1379        let destination_ptr = if destination_handles.is_empty() {
1380            ptr::null()
1381        } else {
1382            destination_handles.as_ptr()
1383        };
1384        let recurrent_input_ptr = recurrent_input_state.map_or(ptr::null_mut(), RnnRecurrentImageState::as_ptr);
1385        let ptr = unsafe {
1386            ffi::mps_rnn_image_inference_layer_encode_sequence(
1387                self.ptr,
1388                command_buffer.as_ptr(),
1389                source_images.len(),
1390                source_ptr,
1391                destination_ptr,
1392                recurrent_input_ptr,
1393            )
1394        };
1395        if ptr.is_null() {
1396            None
1397        } else {
1398            Some(RnnRecurrentImageState { ptr })
1399        }
1400    }
1401}