Skip to main content

apple_mps/
neural.rs

1use crate::ffi;
2use crate::image::Image;
3use crate::matrix::{Matrix, Vector};
4use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
5use core::ffi::c_void;
6use core::ptr;
7
/// Raw values of the MPS `MPSRNNSequenceDirection` enumeration.
pub mod rnn_sequence_direction {
    /// Value of `MPSRNNSequenceDirectionForward`.
    pub const FORWARD: usize = 0;
    /// Value of `MPSRNNSequenceDirectionBackward`.
    pub const BACKWARD: usize = 1;
}
13
/// Raw values of the MPS `MPSRNNBidirectionalCombineMode` enumeration.
pub mod rnn_bidirectional_combine_mode {
    /// Value of `MPSRNNBidirectionalCombineModeNone`.
    pub const NONE: usize = 0;
    /// Value of `MPSRNNBidirectionalCombineModeAdd`.
    pub const ADD: usize = 1;
    /// Value of `MPSRNNBidirectionalCombineModeConcatenate`.
    pub const CONCATENATE: usize = 2;
}
20
/// Raw values of the MPS `MPSCNNConvolutionFlags` option set.
pub mod cnn_convolution_flags {
    /// Value of `MPSCNNConvolutionFlagsNone` (no flags set).
    pub const NONE: usize = 0;
}
25
/// Raw values of the MPS `MPSCNNConvolutionWeightsLayout` enumeration.
pub mod cnn_convolution_weights_layout {
    /// Value of `MPSCNNConvolutionWeightsLayoutOHWI`.
    pub const OHWI: u32 = 0;
}
30
/// Raw values of the MPS `MPSNNConvolutionAccumulatorPrecisionOption` option set.
pub mod cnn_accumulator_precision_option {
    /// Value of `MPSNNConvolutionAccumulatorPrecisionOptionHalf`.
    pub const HALF: usize = 0;
    /// Value of `MPSNNConvolutionAccumulatorPrecisionOptionFloat`.
    pub const FLOAT: usize = 1;
}
36
/// Raw values of the MPS `MPSNNRegularizationType` enumeration.
pub mod nn_regularization_type {
    /// Value of `MPSNNRegularizationTypeNone`.
    pub const NONE: usize = 0;
    /// Value of `MPSNNRegularizationTypeL1`.
    pub const L1: usize = 1;
    /// Value of `MPSNNRegularizationTypeL2`.
    pub const L2: usize = 2;
}
43
44pub use crate::generated::neural::*;
45
/// Declares `$name` as an owning wrapper around a single MPS object pointer.
///
/// The wrapper is intended to be built only from a non-null pointer that carries a +1
/// reference. It gives that reference back in `Drop` and otherwise only lends the raw
/// pointer out through `as_ptr`.
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: the pointer is an opaque MPS object handle that Rust never
        // dereferences. Ownership of the +1 reference moves with the wrapper.
        // NOTE(review): several wrappers mutate the underlying object through `&self`
        // setters, and Apple documents MPS kernels as not thread-safe. `Sync` is
        // therefore only sound if callers never mutate or encode with one handle from
        // several threads at once. Confirm this contract and document it publicly.
        unsafe impl Send for $name {}
        // SAFETY: see the `Send` impl directly above.
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: a non-null `ptr` is the +1 reference this wrapper owns. The
                // field is cleared right after, so it can never be released twice.
                unsafe { ffi::mps_object_release(self.ptr) };
                self.ptr = ptr::null_mut();
            }
        }

        impl $name {
            /// Returns the raw MPS object pointer without transferring ownership.
            ///
            /// The pointer is only valid while `self` is alive.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
75
/// Adds a `result_image` accessor to a filter-node wrapper declared with `opaque_handle!`.
macro_rules! impl_filter_result_image {
    ($name:ident) => {
        impl $name {
            /// Returns the image node produced by this filter node.
            ///
            /// Returns `None` if the FFI call yields a null handle.
            #[must_use]
            pub fn result_image(&self) -> Option<NNImageNode> {
                // SAFETY: `self.ptr` is the live filter-node handle owned by `self`.
                // NOTE(review): the returned node is released when dropped, so this
                // assumes the FFI hands back a +1 retained reference. Confirm in the shim.
                let raw = unsafe { ffi::mps_nn_filter_node_result_image(self.ptr) };
                ptr::NonNull::new(raw).map(|node| NNImageNode { ptr: node.as_ptr() })
            }
        }
    };
}
91
92fn retained_handle(ptr: *mut c_void) -> Option<*mut c_void> {
93    let retained = unsafe { ffi::mps_object_retain(ptr) };
94    if retained.is_null() {
95        None
96    } else {
97        Some(retained)
98    }
99}
100
/// Generates the accessors shared by RNN descriptor wrappers.
///
/// The target type must be declared with `opaque_handle!`. Every accessor forwards
/// `self.ptr` to the matching `mps_rnn_descriptor_*` FFI function.
macro_rules! impl_rnn_descriptor_common {
    ($name:ident) => {
        // SAFETY (every `unsafe` block in this impl): `self.ptr` is the live descriptor
        // handle owned by `self`. The calls pass only that handle and plain values.
        impl $name {
            /// Returns the descriptor's input feature-channel count.
            #[must_use]
            pub fn input_feature_channels(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_input_feature_channels(self.ptr) }
            }

            /// Sets the descriptor's input feature-channel count.
            pub fn set_input_feature_channels(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_input_feature_channels(self.ptr, value) };
            }

            /// Returns the descriptor's output feature-channel count.
            #[must_use]
            pub fn output_feature_channels(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_output_feature_channels(self.ptr) }
            }

            /// Sets the descriptor's output feature-channel count.
            pub fn set_output_feature_channels(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_output_feature_channels(self.ptr, value) };
            }

            /// Returns the `useLayerInputUnitTransformMode` flag as reported by the FFI.
            #[must_use]
            pub fn use_layer_input_unit_transform_mode(&self) -> bool {
                unsafe { ffi::mps_rnn_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
            }

            /// Sets the `useLayerInputUnitTransformMode` flag.
            pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
                unsafe {
                    ffi::mps_rnn_descriptor_set_use_layer_input_unit_transform_mode(self.ptr, value)
                };
            }

            /// Returns whether the descriptor requests 32-bit float weights.
            #[must_use]
            pub fn use_float32_weights(&self) -> bool {
                unsafe { ffi::mps_rnn_descriptor_use_float32_weights(self.ptr) }
            }

            /// Sets whether the descriptor requests 32-bit float weights.
            pub fn set_use_float32_weights(&self, value: bool) {
                unsafe { ffi::mps_rnn_descriptor_set_use_float32_weights(self.ptr, value) };
            }

            /// Returns the raw sequence direction (see `rnn_sequence_direction`).
            #[must_use]
            pub fn layer_sequence_direction(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_layer_sequence_direction(self.ptr) }
            }

            /// Sets the raw sequence direction (see `rnn_sequence_direction`).
            pub fn set_layer_sequence_direction(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_layer_sequence_direction(self.ptr, value) };
            }
        }
    };
}
153
/// Generates the accessors shared by every optimizer wrapper.
///
/// The target type must be declared with `opaque_handle!`. Each accessor forwards
/// `self.ptr` to the matching `mps_nn_optimizer_*` FFI function. Only the learning rate
/// and the gradient-clipping switch get setters here.
macro_rules! impl_optimizer_common {
    ($name:ident) => {
        // SAFETY (every `unsafe` block in this impl): `self.ptr` is the live optimizer
        // handle owned by `self`. The calls pass only that handle and plain values.
        impl $name {
            /// Returns the optimizer's current learning rate.
            #[must_use]
            pub fn learning_rate(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_learning_rate(self.ptr) }
            }

            /// Replaces the optimizer's learning rate.
            pub fn set_learning_rate(&self, value: f32) {
                unsafe { ffi::mps_nn_optimizer_set_learning_rate(self.ptr, value) };
            }

            /// Returns the gradient rescale factor.
            #[must_use]
            pub fn gradient_rescale(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_rescale(self.ptr) }
            }

            /// Returns whether gradient clipping is enabled.
            #[must_use]
            pub fn apply_gradient_clipping(&self) -> bool {
                unsafe { ffi::mps_nn_optimizer_apply_gradient_clipping(self.ptr) }
            }

            /// Enables or disables gradient clipping.
            pub fn set_apply_gradient_clipping(&self, value: bool) {
                unsafe { ffi::mps_nn_optimizer_set_apply_gradient_clipping(self.ptr, value) };
            }

            /// Returns the upper gradient-clipping bound.
            #[must_use]
            pub fn gradient_clip_max(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_clip_max(self.ptr) }
            }

            /// Returns the lower gradient-clipping bound.
            #[must_use]
            pub fn gradient_clip_min(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_clip_min(self.ptr) }
            }

            /// Returns the regularization scale.
            #[must_use]
            pub fn regularization_scale(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_regularization_scale(self.ptr) }
            }

            /// Returns the raw regularization type (presumably one of the
            /// `nn_regularization_type` constants).
            #[must_use]
            pub fn regularization_type(&self) -> usize {
                unsafe { ffi::mps_nn_optimizer_regularization_type(self.ptr) }
            }
        }
    };
}
202
203opaque_handle!(NNImageNode);
204impl NNImageNode {
205    #[must_use]
206    pub fn new() -> Option<Self> {
207        let ptr = unsafe { ffi::mps_nn_image_node_new() };
208        if ptr.is_null() {
209            None
210        } else {
211            Some(Self { ptr })
212        }
213    }
214
215    #[must_use]
216    pub fn exported() -> Option<Self> {
217        let ptr = unsafe { ffi::mps_nn_image_node_exported() };
218        if ptr.is_null() {
219            None
220        } else {
221            Some(Self { ptr })
222        }
223    }
224
225    #[must_use]
226    pub fn format(&self) -> usize {
227        unsafe { ffi::mps_nn_image_node_format(self.ptr) }
228    }
229
230    pub fn set_format(&self, format: usize) {
231        unsafe { ffi::mps_nn_image_node_set_format(self.ptr, format) };
232    }
233
234    #[must_use]
235    pub fn export_from_graph(&self) -> bool {
236        unsafe { ffi::mps_nn_image_node_export_from_graph(self.ptr) }
237    }
238
239    pub fn set_export_from_graph(&self, export: bool) {
240        unsafe { ffi::mps_nn_image_node_set_export_from_graph(self.ptr, export) };
241    }
242
243    #[must_use]
244    pub fn synchronize_resource(&self) -> bool {
245        unsafe { ffi::mps_nn_image_node_synchronize_resource(self.ptr) }
246    }
247
248    pub fn set_synchronize_resource(&self, synchronize: bool) {
249        unsafe { ffi::mps_nn_image_node_set_synchronize_resource(self.ptr, synchronize) };
250    }
251
252    pub fn use_default_allocator(&self) {
253        unsafe { ffi::mps_nn_image_node_use_default_allocator(self.ptr) };
254    }
255}
256
257opaque_handle!(CnnNeuronReluNode);
258impl CnnNeuronReluNode {
259    #[must_use]
260    pub fn new(source: &NNImageNode, a: f32) -> Option<Self> {
261        let ptr = unsafe { ffi::mps_cnn_neuron_relu_node_new(source.as_ptr(), a) };
262        if ptr.is_null() {
263            None
264        } else {
265            Some(Self { ptr })
266        }
267    }
268}
269impl_filter_result_image!(CnnNeuronReluNode);
270
271opaque_handle!(CnnPoolingMaxNode);
272impl CnnPoolingMaxNode {
273    #[must_use]
274    pub fn new(source: &NNImageNode, filter_size: usize, stride: usize) -> Option<Self> {
275        let ptr =
276            unsafe { ffi::mps_cnn_pooling_max_node_new(source.as_ptr(), filter_size, stride) };
277        if ptr.is_null() {
278            None
279        } else {
280            Some(Self { ptr })
281        }
282    }
283}
284impl_filter_result_image!(CnnPoolingMaxNode);
285
286opaque_handle!(CnnSoftMaxNode);
287impl CnnSoftMaxNode {
288    #[must_use]
289    pub fn new(source: &NNImageNode) -> Option<Self> {
290        let ptr = unsafe { ffi::mps_cnn_softmax_node_new(source.as_ptr()) };
291        if ptr.is_null() {
292            None
293        } else {
294            Some(Self { ptr })
295        }
296    }
297}
298impl_filter_result_image!(CnnSoftMaxNode);
299
300opaque_handle!(CnnUpsamplingNearestNode);
301impl CnnUpsamplingNearestNode {
302    #[must_use]
303    pub fn new(source: &NNImageNode, scale_x: usize, scale_y: usize) -> Option<Self> {
304        let ptr =
305            unsafe { ffi::mps_cnn_upsampling_nearest_node_new(source.as_ptr(), scale_x, scale_y) };
306        if ptr.is_null() {
307            None
308        } else {
309            Some(Self { ptr })
310        }
311    }
312}
313impl_filter_result_image!(CnnUpsamplingNearestNode);
314
315opaque_handle!(NNGraph);
316impl NNGraph {
317    #[must_use]
318    pub fn new(
319        device: &MetalDevice,
320        result_image: &NNImageNode,
321        result_image_is_needed: bool,
322    ) -> Option<Self> {
323        let ptr = unsafe {
324            ffi::mps_nn_graph_new(
325                device.as_ptr(),
326                result_image.as_ptr(),
327                result_image_is_needed,
328            )
329        };
330        if ptr.is_null() {
331            None
332        } else {
333            Some(Self { ptr })
334        }
335    }
336
337    #[must_use]
338    pub fn source_image_count(&self) -> usize {
339        unsafe { ffi::mps_nn_graph_source_image_count(self.ptr) }
340    }
341
342    #[must_use]
343    pub fn format(&self) -> usize {
344        unsafe { ffi::mps_nn_graph_format(self.ptr) }
345    }
346
347    pub fn set_format(&self, format: usize) {
348        unsafe { ffi::mps_nn_graph_set_format(self.ptr, format) };
349    }
350
351    pub fn set_output_state_is_temporary(&self, temporary: bool) {
352        unsafe { ffi::mps_nn_graph_set_output_state_is_temporary(self.ptr, temporary) };
353    }
354
355    pub fn use_default_destination_image_allocator(&self) {
356        unsafe { ffi::mps_nn_graph_use_default_destination_image_allocator(self.ptr) };
357    }
358
359    pub fn reload_from_data_sources(&self) {
360        unsafe { ffi::mps_nn_graph_reload_from_data_sources(self.ptr) };
361    }
362
363    #[must_use]
364    pub fn encode(
365        &self,
366        command_buffer: &CommandBuffer,
367        source_images: &[&Image],
368    ) -> Option<Image> {
369        let handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
370        let source_handles = if handles.is_empty() {
371            ptr::null()
372        } else {
373            handles.as_ptr()
374        };
375        let ptr = unsafe {
376            ffi::mps_nn_graph_encode(
377                self.ptr,
378                command_buffer.as_ptr(),
379                source_images.len(),
380                source_handles,
381            )
382        };
383        if ptr.is_null() {
384            None
385        } else {
386            Some(unsafe { Image::from_raw(ptr) })
387        }
388    }
389}
390
391opaque_handle!(CnnConvolutionDescriptor);
392impl CnnConvolutionDescriptor {
393    #[must_use]
394    pub fn new(
395        kernel_width: usize,
396        kernel_height: usize,
397        input_feature_channels: usize,
398        output_feature_channels: usize,
399    ) -> Option<Self> {
400        let ptr = unsafe {
401            ffi::mps_cnn_convolution_descriptor_new(
402                kernel_width,
403                kernel_height,
404                input_feature_channels,
405                output_feature_channels,
406            )
407        };
408        if ptr.is_null() {
409            None
410        } else {
411            Some(Self { ptr })
412        }
413    }
414
415    #[must_use]
416    pub fn kernel_width(&self) -> usize {
417        unsafe { ffi::mps_cnn_convolution_descriptor_kernel_width(self.ptr) }
418    }
419
420    #[must_use]
421    pub fn kernel_height(&self) -> usize {
422        unsafe { ffi::mps_cnn_convolution_descriptor_kernel_height(self.ptr) }
423    }
424
425    #[must_use]
426    pub fn input_feature_channels(&self) -> usize {
427        unsafe { ffi::mps_cnn_convolution_descriptor_input_feature_channels(self.ptr) }
428    }
429
430    #[must_use]
431    pub fn output_feature_channels(&self) -> usize {
432        unsafe { ffi::mps_cnn_convolution_descriptor_output_feature_channels(self.ptr) }
433    }
434
435    #[must_use]
436    pub fn stride_in_pixels_x(&self) -> usize {
437        unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_x(self.ptr) }
438    }
439
440    pub fn set_stride_in_pixels_x(&self, value: usize) {
441        unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_x(self.ptr, value) };
442    }
443
444    #[must_use]
445    pub fn stride_in_pixels_y(&self) -> usize {
446        unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_y(self.ptr) }
447    }
448
449    pub fn set_stride_in_pixels_y(&self, value: usize) {
450        unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_y(self.ptr, value) };
451    }
452
453    #[must_use]
454    pub fn groups(&self) -> usize {
455        unsafe { ffi::mps_cnn_convolution_descriptor_groups(self.ptr) }
456    }
457
458    pub fn set_groups(&self, value: usize) {
459        unsafe { ffi::mps_cnn_convolution_descriptor_set_groups(self.ptr, value) };
460    }
461
462    #[must_use]
463    pub fn dilation_rate_x(&self) -> usize {
464        unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_x(self.ptr) }
465    }
466
467    pub fn set_dilation_rate_x(&self, value: usize) {
468        unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_x(self.ptr, value) };
469    }
470
471    #[must_use]
472    pub fn dilation_rate_y(&self) -> usize {
473        unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_y(self.ptr) }
474    }
475
476    pub fn set_dilation_rate_y(&self, value: usize) {
477        unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_y(self.ptr, value) };
478    }
479}
480
481opaque_handle!(RnnSingleGateDescriptor);
482impl RnnSingleGateDescriptor {
483    #[must_use]
484    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
485        let ptr = unsafe {
486            ffi::mps_rnn_single_gate_descriptor_new(input_feature_channels, output_feature_channels)
487        };
488        if ptr.is_null() {
489            None
490        } else {
491            Some(Self { ptr })
492        }
493    }
494
495    #[must_use]
496    pub fn input_feature_channels(&self) -> usize {
497        unsafe { ffi::mps_rnn_single_gate_descriptor_input_feature_channels(self.ptr) }
498    }
499
500    pub fn set_input_feature_channels(&self, value: usize) {
501        unsafe { ffi::mps_rnn_single_gate_descriptor_set_input_feature_channels(self.ptr, value) };
502    }
503
504    #[must_use]
505    pub fn output_feature_channels(&self) -> usize {
506        unsafe { ffi::mps_rnn_single_gate_descriptor_output_feature_channels(self.ptr) }
507    }
508
509    pub fn set_output_feature_channels(&self, value: usize) {
510        unsafe { ffi::mps_rnn_single_gate_descriptor_set_output_feature_channels(self.ptr, value) };
511    }
512
513    #[must_use]
514    pub fn use_layer_input_unit_transform_mode(&self) -> bool {
515        unsafe { ffi::mps_rnn_single_gate_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
516    }
517
518    pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
519        unsafe {
520            ffi::mps_rnn_single_gate_descriptor_set_use_layer_input_unit_transform_mode(
521                self.ptr, value,
522            );
523        };
524    }
525
526    #[must_use]
527    pub fn use_float32_weights(&self) -> bool {
528        unsafe { ffi::mps_rnn_single_gate_descriptor_use_float32_weights(self.ptr) }
529    }
530
531    pub fn set_use_float32_weights(&self, value: bool) {
532        unsafe { ffi::mps_rnn_single_gate_descriptor_set_use_float32_weights(self.ptr, value) };
533    }
534
535    #[must_use]
536    pub fn layer_sequence_direction(&self) -> usize {
537        unsafe { ffi::mps_rnn_single_gate_descriptor_layer_sequence_direction(self.ptr) }
538    }
539
540    pub fn set_layer_sequence_direction(&self, value: usize) {
541        unsafe {
542            ffi::mps_rnn_single_gate_descriptor_set_layer_sequence_direction(self.ptr, value);
543        };
544    }
545
546    #[must_use]
547    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
548        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
549    }
550}
551
552opaque_handle!(CnnConvolution);
553impl CnnConvolution {
554    #[must_use]
555    pub fn new(
556        device: &MetalDevice,
557        descriptor: &CnnConvolutionDescriptor,
558        kernel_weights: &[f32],
559        bias_terms: Option<&[f32]>,
560        flags: usize,
561    ) -> Option<Self> {
562        if kernel_weights.is_empty() {
563            return None;
564        }
565        let bias_terms_ptr = bias_terms.map_or(ptr::null(), <[f32]>::as_ptr);
566        let ptr = unsafe {
567            ffi::mps_cnn_convolution_new(
568                device.as_ptr(),
569                descriptor.as_ptr(),
570                kernel_weights.as_ptr(),
571                bias_terms_ptr,
572                flags,
573            )
574        };
575        if ptr.is_null() {
576            None
577        } else {
578            Some(Self { ptr })
579        }
580    }
581
582    #[must_use]
583    pub fn input_feature_channels(&self) -> usize {
584        unsafe { ffi::mps_cnn_convolution_input_feature_channels(self.ptr) }
585    }
586
587    #[must_use]
588    pub fn output_feature_channels(&self) -> usize {
589        unsafe { ffi::mps_cnn_convolution_output_feature_channels(self.ptr) }
590    }
591
592    #[must_use]
593    pub fn groups(&self) -> usize {
594        unsafe { ffi::mps_cnn_convolution_groups(self.ptr) }
595    }
596
597    #[must_use]
598    pub fn sub_pixel_scale_factor(&self) -> usize {
599        unsafe { ffi::mps_cnn_convolution_sub_pixel_scale_factor(self.ptr) }
600    }
601
602    #[must_use]
603    pub fn channel_multiplier(&self) -> usize {
604        unsafe { ffi::mps_cnn_convolution_channel_multiplier(self.ptr) }
605    }
606
607    #[must_use]
608    pub fn accumulator_precision_option(&self) -> usize {
609        unsafe { ffi::mps_cnn_convolution_accumulator_precision_option(self.ptr) }
610    }
611
612    pub fn set_accumulator_precision_option(&self, value: usize) {
613        unsafe { ffi::mps_cnn_convolution_set_accumulator_precision_option(self.ptr, value) };
614    }
615
616    pub fn encode_image(
617        &self,
618        command_buffer: &CommandBuffer,
619        source: &Image,
620        destination: &Image,
621    ) {
622        unsafe {
623            ffi::mps_cnn_convolution_encode_image(
624                self.ptr,
625                command_buffer.as_ptr(),
626                source.as_ptr(),
627                destination.as_ptr(),
628            );
629        };
630    }
631}
632
633opaque_handle!(CnnConvolutionWeightsAndBiasesState);
634impl CnnConvolutionWeightsAndBiasesState {
635    #[must_use]
636    pub fn new_with_buffers(weights: &MetalBuffer, biases: Option<&MetalBuffer>) -> Option<Self> {
637        let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
638        let ptr = unsafe {
639            ffi::mps_cnn_convolution_weights_and_biases_state_new(weights.as_ptr(), biases_ptr)
640        };
641        if ptr.is_null() {
642            None
643        } else {
644            Some(Self { ptr })
645        }
646    }
647
648    #[must_use]
649    pub fn new_with_offsets(
650        weights: &MetalBuffer,
651        weights_offset: usize,
652        biases: Option<&MetalBuffer>,
653        biases_offset: usize,
654        descriptor: &CnnConvolutionDescriptor,
655    ) -> Option<Self> {
656        let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
657        let ptr = unsafe {
658            ffi::mps_cnn_convolution_weights_and_biases_state_new_with_offsets(
659                weights.as_ptr(),
660                weights_offset,
661                biases_ptr,
662                biases_offset,
663                descriptor.as_ptr(),
664            )
665        };
666        if ptr.is_null() {
667            None
668        } else {
669            Some(Self { ptr })
670        }
671    }
672
673    #[must_use]
674    pub fn new_with_device(
675        device: &MetalDevice,
676        descriptor: &CnnConvolutionDescriptor,
677    ) -> Option<Self> {
678        let ptr = unsafe {
679            ffi::mps_cnn_convolution_weights_and_biases_state_new_with_device(
680                device.as_ptr(),
681                descriptor.as_ptr(),
682            )
683        };
684        if ptr.is_null() {
685            None
686        } else {
687            Some(Self { ptr })
688        }
689    }
690
691    #[must_use]
692    pub fn weights_offset(&self) -> usize {
693        unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_weights_offset(self.ptr) }
694    }
695
696    #[must_use]
697    pub fn biases_offset(&self) -> usize {
698        unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_biases_offset(self.ptr) }
699    }
700}
701
702opaque_handle!(NNOptimizerDescriptor);
703impl NNOptimizerDescriptor {
704    #[must_use]
705    pub fn new(
706        learning_rate: f32,
707        gradient_rescale: f32,
708        regularization_type: usize,
709        regularization_scale: f32,
710    ) -> Option<Self> {
711        let ptr = unsafe {
712            ffi::mps_nn_optimizer_descriptor_new(
713                learning_rate,
714                gradient_rescale,
715                regularization_type,
716                regularization_scale,
717            )
718        };
719        if ptr.is_null() {
720            None
721        } else {
722            Some(Self { ptr })
723        }
724    }
725
726    #[must_use]
727    pub fn with_gradient_clipping(
728        learning_rate: f32,
729        gradient_rescale: f32,
730        apply_gradient_clipping: bool,
731        gradient_clip_max: f32,
732        gradient_clip_min: f32,
733        regularization_type: usize,
734        regularization_scale: f32,
735    ) -> Option<Self> {
736        let ptr = unsafe {
737            ffi::mps_nn_optimizer_descriptor_new_with_gradient_clipping(
738                learning_rate,
739                gradient_rescale,
740                apply_gradient_clipping,
741                gradient_clip_max,
742                gradient_clip_min,
743                regularization_type,
744                regularization_scale,
745            )
746        };
747        if ptr.is_null() {
748            None
749        } else {
750            Some(Self { ptr })
751        }
752    }
753
754    #[must_use]
755    pub fn learning_rate(&self) -> f32 {
756        unsafe { ffi::mps_nn_optimizer_descriptor_learning_rate(self.ptr) }
757    }
758
759    pub fn set_learning_rate(&self, value: f32) {
760        unsafe { ffi::mps_nn_optimizer_descriptor_set_learning_rate(self.ptr, value) };
761    }
762
763    #[must_use]
764    pub fn gradient_rescale(&self) -> f32 {
765        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_rescale(self.ptr) }
766    }
767
768    pub fn set_gradient_rescale(&self, value: f32) {
769        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_rescale(self.ptr, value) };
770    }
771
772    #[must_use]
773    pub fn apply_gradient_clipping(&self) -> bool {
774        unsafe { ffi::mps_nn_optimizer_descriptor_apply_gradient_clipping(self.ptr) }
775    }
776
777    pub fn set_apply_gradient_clipping(&self, value: bool) {
778        unsafe { ffi::mps_nn_optimizer_descriptor_set_apply_gradient_clipping(self.ptr, value) };
779    }
780
781    #[must_use]
782    pub fn gradient_clip_max(&self) -> f32 {
783        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_max(self.ptr) }
784    }
785
786    pub fn set_gradient_clip_max(&self, value: f32) {
787        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_max(self.ptr, value) };
788    }
789
790    #[must_use]
791    pub fn gradient_clip_min(&self) -> f32 {
792        unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_min(self.ptr) }
793    }
794
795    pub fn set_gradient_clip_min(&self, value: f32) {
796        unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_min(self.ptr, value) };
797    }
798
799    #[must_use]
800    pub fn regularization_scale(&self) -> f32 {
801        unsafe { ffi::mps_nn_optimizer_descriptor_regularization_scale(self.ptr) }
802    }
803
804    pub fn set_regularization_scale(&self, value: f32) {
805        unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_scale(self.ptr, value) };
806    }
807
808    #[must_use]
809    pub fn regularization_type(&self) -> usize {
810        unsafe { ffi::mps_nn_optimizer_descriptor_regularization_type(self.ptr) }
811    }
812
813    pub fn set_regularization_type(&self, value: usize) {
814        unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_type(self.ptr, value) };
815    }
816}
817
// Owning handle to an optimizer viewed through its common base: it exposes only the
// shared accessors from `impl_optimizer_common!`. The concrete optimizer wrappers'
// `as_optimizer` methods produce it by taking an extra reference.
opaque_handle!(NNOptimizer);
impl_optimizer_common!(NNOptimizer);
820
821opaque_handle!(NNOptimizerStochasticGradientDescent);
822impl_optimizer_common!(NNOptimizerStochasticGradientDescent);
823impl NNOptimizerStochasticGradientDescent {
824    #[must_use]
825    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
826        let ptr = unsafe { ffi::mps_nn_optimizer_sgd_new(device.as_ptr(), learning_rate) };
827        if ptr.is_null() {
828            None
829        } else {
830            Some(Self { ptr })
831        }
832    }
833
834    #[must_use]
835    pub fn new_with_options(
836        device: &MetalDevice,
837        momentum_scale: f32,
838        use_nesterov_momentum: bool,
839        optimizer_descriptor: &NNOptimizerDescriptor,
840    ) -> Option<Self> {
841        let ptr = unsafe {
842            ffi::mps_nn_optimizer_sgd_new_with_options(
843                device.as_ptr(),
844                momentum_scale,
845                use_nesterov_momentum,
846                optimizer_descriptor.as_ptr(),
847            )
848        };
849        if ptr.is_null() {
850            None
851        } else {
852            Some(Self { ptr })
853        }
854    }
855
856    #[must_use]
857    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
858        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
859    }
860
861    #[must_use]
862    pub fn momentum_scale(&self) -> f32 {
863        unsafe { ffi::mps_nn_optimizer_sgd_momentum_scale(self.ptr) }
864    }
865
866    #[must_use]
867    pub fn use_nesterov_momentum(&self) -> bool {
868        unsafe { ffi::mps_nn_optimizer_sgd_use_nesterov_momentum(self.ptr) }
869    }
870
871    pub fn encode_vector(
872        &self,
873        command_buffer: &CommandBuffer,
874        input_gradient_vector: &Vector,
875        input_values_vector: &Vector,
876        input_momentum_vector: Option<&Vector>,
877        result_values_vector: &Vector,
878    ) {
879        let input_momentum_ptr = input_momentum_vector.map_or(ptr::null_mut(), Vector::as_ptr);
880        unsafe {
881            ffi::mps_nn_optimizer_sgd_encode_vector(
882                self.ptr,
883                command_buffer.as_ptr(),
884                input_gradient_vector.as_ptr(),
885                input_values_vector.as_ptr(),
886                input_momentum_ptr,
887                result_values_vector.as_ptr(),
888            );
889        };
890    }
891
892    pub fn encode_matrix(
893        &self,
894        command_buffer: &CommandBuffer,
895        input_gradient_matrix: &Matrix,
896        input_values_matrix: &Matrix,
897        input_momentum_matrix: Option<&Matrix>,
898        result_values_matrix: &Matrix,
899    ) {
900        let input_momentum_ptr = input_momentum_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
901        unsafe {
902            ffi::mps_nn_optimizer_sgd_encode_matrix(
903                self.ptr,
904                command_buffer.as_ptr(),
905                input_gradient_matrix.as_ptr(),
906                input_values_matrix.as_ptr(),
907                input_momentum_ptr,
908                result_values_matrix.as_ptr(),
909            );
910        };
911    }
912}
913
914opaque_handle!(NNOptimizerRmsProp);
915impl_optimizer_common!(NNOptimizerRmsProp);
916impl NNOptimizerRmsProp {
917    #[must_use]
918    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
919        let ptr = unsafe { ffi::mps_nn_optimizer_rmsprop_new(device.as_ptr(), learning_rate) };
920        if ptr.is_null() {
921            None
922        } else {
923            Some(Self { ptr })
924        }
925    }
926
927    #[must_use]
928    pub fn new_with_options(
929        device: &MetalDevice,
930        decay: f64,
931        epsilon: f32,
932        optimizer_descriptor: &NNOptimizerDescriptor,
933    ) -> Option<Self> {
934        let ptr = unsafe {
935            ffi::mps_nn_optimizer_rmsprop_new_with_options(
936                device.as_ptr(),
937                decay,
938                epsilon,
939                optimizer_descriptor.as_ptr(),
940            )
941        };
942        if ptr.is_null() {
943            None
944        } else {
945            Some(Self { ptr })
946        }
947    }
948
949    #[must_use]
950    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
951        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
952    }
953
954    #[must_use]
955    pub fn decay(&self) -> f64 {
956        unsafe { ffi::mps_nn_optimizer_rmsprop_decay(self.ptr) }
957    }
958
959    #[must_use]
960    pub fn epsilon(&self) -> f32 {
961        unsafe { ffi::mps_nn_optimizer_rmsprop_epsilon(self.ptr) }
962    }
963
964    pub fn encode_vector(
965        &self,
966        command_buffer: &CommandBuffer,
967        input_gradient_vector: &Vector,
968        input_values_vector: &Vector,
969        input_sum_of_squares_vector: &Vector,
970        result_values_vector: &Vector,
971    ) {
972        unsafe {
973            ffi::mps_nn_optimizer_rmsprop_encode_vector(
974                self.ptr,
975                command_buffer.as_ptr(),
976                input_gradient_vector.as_ptr(),
977                input_values_vector.as_ptr(),
978                input_sum_of_squares_vector.as_ptr(),
979                result_values_vector.as_ptr(),
980            );
981        };
982    }
983
984    pub fn encode_matrix(
985        &self,
986        command_buffer: &CommandBuffer,
987        input_gradient_matrix: &Matrix,
988        input_values_matrix: &Matrix,
989        input_sum_of_squares_matrix: &Matrix,
990        result_values_matrix: &Matrix,
991    ) {
992        unsafe {
993            ffi::mps_nn_optimizer_rmsprop_encode_matrix(
994                self.ptr,
995                command_buffer.as_ptr(),
996                input_gradient_matrix.as_ptr(),
997                input_values_matrix.as_ptr(),
998                input_sum_of_squares_matrix.as_ptr(),
999                result_values_matrix.as_ptr(),
1000            );
1001        };
1002    }
1003}
1004
1005opaque_handle!(NNOptimizerAdam);
1006impl_optimizer_common!(NNOptimizerAdam);
1007impl NNOptimizerAdam {
1008    #[must_use]
1009    pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
1010        let ptr = unsafe { ffi::mps_nn_optimizer_adam_new(device.as_ptr(), learning_rate) };
1011        if ptr.is_null() {
1012            None
1013        } else {
1014            Some(Self { ptr })
1015        }
1016    }
1017
1018    #[must_use]
1019    pub fn new_with_options(
1020        device: &MetalDevice,
1021        beta1: f64,
1022        beta2: f64,
1023        epsilon: f32,
1024        time_step: usize,
1025        optimizer_descriptor: &NNOptimizerDescriptor,
1026    ) -> Option<Self> {
1027        let ptr = unsafe {
1028            ffi::mps_nn_optimizer_adam_new_with_options(
1029                device.as_ptr(),
1030                beta1,
1031                beta2,
1032                epsilon,
1033                time_step,
1034                optimizer_descriptor.as_ptr(),
1035            )
1036        };
1037        if ptr.is_null() {
1038            None
1039        } else {
1040            Some(Self { ptr })
1041        }
1042    }
1043
1044    #[must_use]
1045    pub fn as_optimizer(&self) -> Option<NNOptimizer> {
1046        retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
1047    }
1048
1049    #[must_use]
1050    pub fn beta1(&self) -> f64 {
1051        unsafe { ffi::mps_nn_optimizer_adam_beta1(self.ptr) }
1052    }
1053
1054    #[must_use]
1055    pub fn beta2(&self) -> f64 {
1056        unsafe { ffi::mps_nn_optimizer_adam_beta2(self.ptr) }
1057    }
1058
1059    #[must_use]
1060    pub fn epsilon(&self) -> f32 {
1061        unsafe { ffi::mps_nn_optimizer_adam_epsilon(self.ptr) }
1062    }
1063
1064    #[must_use]
1065    pub fn time_step(&self) -> usize {
1066        unsafe { ffi::mps_nn_optimizer_adam_time_step(self.ptr) }
1067    }
1068
1069    pub fn set_time_step(&self, value: usize) {
1070        unsafe { ffi::mps_nn_optimizer_adam_set_time_step(self.ptr, value) };
1071    }
1072
1073    pub fn encode_vector(
1074        &self,
1075        command_buffer: &CommandBuffer,
1076        input_gradient_vector: &Vector,
1077        input_values_vector: &Vector,
1078        input_momentum_vector: &Vector,
1079        input_velocity_vector: &Vector,
1080        result_values_vector: &Vector,
1081    ) {
1082        unsafe {
1083            ffi::mps_nn_optimizer_adam_encode_vector(
1084                self.ptr,
1085                command_buffer.as_ptr(),
1086                input_gradient_vector.as_ptr(),
1087                input_values_vector.as_ptr(),
1088                input_momentum_vector.as_ptr(),
1089                input_velocity_vector.as_ptr(),
1090                result_values_vector.as_ptr(),
1091            );
1092        };
1093    }
1094
1095    pub fn encode_matrix(
1096        &self,
1097        command_buffer: &CommandBuffer,
1098        input_gradient_matrix: &Matrix,
1099        input_values_matrix: &Matrix,
1100        input_momentum_matrix: &Matrix,
1101        input_velocity_matrix: &Matrix,
1102        result_values_matrix: &Matrix,
1103    ) {
1104        unsafe {
1105            ffi::mps_nn_optimizer_adam_encode_matrix(
1106                self.ptr,
1107                command_buffer.as_ptr(),
1108                input_gradient_matrix.as_ptr(),
1109                input_values_matrix.as_ptr(),
1110                input_momentum_matrix.as_ptr(),
1111                input_velocity_matrix.as_ptr(),
1112                result_values_matrix.as_ptr(),
1113            );
1114        };
1115    }
1116
1117    #[allow(clippy::too_many_arguments)]
1118    pub fn encode_amsgrad_vector(
1119        &self,
1120        command_buffer: &CommandBuffer,
1121        input_gradient_vector: &Vector,
1122        input_values_vector: &Vector,
1123        input_momentum_vector: &Vector,
1124        input_velocity_vector: &Vector,
1125        maximum_velocity_vector: Option<&Vector>,
1126        result_values_vector: &Vector,
1127    ) {
1128        let maximum_velocity_ptr = maximum_velocity_vector.map_or(ptr::null_mut(), Vector::as_ptr);
1129        unsafe {
1130            ffi::mps_nn_optimizer_adam_encode_amsgrad_vector(
1131                self.ptr,
1132                command_buffer.as_ptr(),
1133                input_gradient_vector.as_ptr(),
1134                input_values_vector.as_ptr(),
1135                input_momentum_vector.as_ptr(),
1136                input_velocity_vector.as_ptr(),
1137                maximum_velocity_ptr,
1138                result_values_vector.as_ptr(),
1139            );
1140        };
1141    }
1142
1143    #[allow(clippy::too_many_arguments)]
1144    pub fn encode_amsgrad_matrix(
1145        &self,
1146        command_buffer: &CommandBuffer,
1147        input_gradient_matrix: &Matrix,
1148        input_values_matrix: &Matrix,
1149        input_momentum_matrix: &Matrix,
1150        input_velocity_matrix: &Matrix,
1151        maximum_velocity_matrix: Option<&Matrix>,
1152        result_values_matrix: &Matrix,
1153    ) {
1154        let maximum_velocity_ptr = maximum_velocity_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
1155        unsafe {
1156            ffi::mps_nn_optimizer_adam_encode_amsgrad_matrix(
1157                self.ptr,
1158                command_buffer.as_ptr(),
1159                input_gradient_matrix.as_ptr(),
1160                input_values_matrix.as_ptr(),
1161                input_momentum_matrix.as_ptr(),
1162                input_velocity_matrix.as_ptr(),
1163                maximum_velocity_ptr,
1164                result_values_matrix.as_ptr(),
1165            );
1166        };
1167    }
1168}
1169
// Base RNN descriptor handle: `opaque_handle!` defines the owning wrapper (released on
// drop), and `impl_rnn_descriptor_common!` adds the accessors shared with the GRU/LSTM
// descriptor types below. Plain comments are used because rustdoc ignores doc comments
// on macro invocations.
opaque_handle!(RnnDescriptor);
impl_rnn_descriptor_common!(RnnDescriptor);
1172
1173opaque_handle!(GruDescriptor);
1174impl_rnn_descriptor_common!(GruDescriptor);
1175impl GruDescriptor {
1176    #[must_use]
1177    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1178        let ptr =
1179            unsafe { ffi::mps_gru_descriptor_new(input_feature_channels, output_feature_channels) };
1180        if ptr.is_null() {
1181            None
1182        } else {
1183            Some(Self { ptr })
1184        }
1185    }
1186
1187    #[must_use]
1188    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1189        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1190    }
1191
1192    #[must_use]
1193    pub fn gate_pnorm_value(&self) -> f32 {
1194        unsafe { ffi::mps_gru_descriptor_gate_pnorm_value(self.ptr) }
1195    }
1196
1197    pub fn set_gate_pnorm_value(&self, value: f32) {
1198        unsafe { ffi::mps_gru_descriptor_set_gate_pnorm_value(self.ptr, value) };
1199    }
1200
1201    #[must_use]
1202    pub fn flip_output_gates(&self) -> bool {
1203        unsafe { ffi::mps_gru_descriptor_flip_output_gates(self.ptr) }
1204    }
1205
1206    pub fn set_flip_output_gates(&self, value: bool) {
1207        unsafe { ffi::mps_gru_descriptor_set_flip_output_gates(self.ptr, value) };
1208    }
1209}
1210
1211opaque_handle!(LstmDescriptor);
1212impl_rnn_descriptor_common!(LstmDescriptor);
1213impl LstmDescriptor {
1214    #[must_use]
1215    pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1216        let ptr = unsafe {
1217            ffi::mps_lstm_descriptor_new(input_feature_channels, output_feature_channels)
1218        };
1219        if ptr.is_null() {
1220            None
1221        } else {
1222            Some(Self { ptr })
1223        }
1224    }
1225
1226    #[must_use]
1227    pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1228        retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1229    }
1230
1231    #[must_use]
1232    pub fn memory_weights_are_diagonal(&self) -> bool {
1233        unsafe { ffi::mps_lstm_descriptor_memory_weights_are_diagonal(self.ptr) }
1234    }
1235
1236    pub fn set_memory_weights_are_diagonal(&self, value: bool) {
1237        unsafe { ffi::mps_lstm_descriptor_set_memory_weights_are_diagonal(self.ptr, value) };
1238    }
1239
1240    #[must_use]
1241    pub fn cell_to_output_neuron_type(&self) -> usize {
1242        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_type(self.ptr) }
1243    }
1244
1245    pub fn set_cell_to_output_neuron_type(&self, value: usize) {
1246        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_type(self.ptr, value) };
1247    }
1248
1249    #[must_use]
1250    pub fn cell_to_output_neuron_param_a(&self) -> f32 {
1251        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_a(self.ptr) }
1252    }
1253
1254    pub fn set_cell_to_output_neuron_param_a(&self, value: f32) {
1255        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_a(self.ptr, value) };
1256    }
1257
1258    #[must_use]
1259    pub fn cell_to_output_neuron_param_b(&self) -> f32 {
1260        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_b(self.ptr) }
1261    }
1262
1263    pub fn set_cell_to_output_neuron_param_b(&self, value: f32) {
1264        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_b(self.ptr, value) };
1265    }
1266
1267    #[must_use]
1268    pub fn cell_to_output_neuron_param_c(&self) -> f32 {
1269        unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_c(self.ptr) }
1270    }
1271
1272    pub fn set_cell_to_output_neuron_param_c(&self, value: f32) {
1273        unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_c(self.ptr, value) };
1274    }
1275}
1276
1277opaque_handle!(RnnRecurrentImageState);
1278impl RnnRecurrentImageState {
1279    #[must_use]
1280    pub fn recurrent_output_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1281        let ptr = unsafe {
1282            ffi::mps_rnn_recurrent_image_state_recurrent_output_image(self.ptr, layer_index)
1283        };
1284        if ptr.is_null() {
1285            None
1286        } else {
1287            Some(unsafe { Image::from_raw(ptr) })
1288        }
1289    }
1290
1291    #[must_use]
1292    pub fn memory_cell_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1293        let ptr =
1294            unsafe { ffi::mps_rnn_recurrent_image_state_memory_cell_image(self.ptr, layer_index) };
1295        if ptr.is_null() {
1296            None
1297        } else {
1298            Some(unsafe { Image::from_raw(ptr) })
1299        }
1300    }
1301}
1302
1303opaque_handle!(RnnImageInferenceLayer);
1304impl RnnImageInferenceLayer {
1305    #[must_use]
1306    pub fn new(device: &MetalDevice, descriptor: &RnnDescriptor) -> Option<Self> {
1307        let ptr =
1308            unsafe { ffi::mps_rnn_image_inference_layer_new(device.as_ptr(), descriptor.as_ptr()) };
1309        if ptr.is_null() {
1310            None
1311        } else {
1312            Some(Self { ptr })
1313        }
1314    }
1315
1316    #[must_use]
1317    pub fn new_stack(device: &MetalDevice, descriptors: &[&RnnDescriptor]) -> Option<Self> {
1318        let handles: Vec<_> = descriptors
1319            .iter()
1320            .map(|descriptor| descriptor.as_ptr())
1321            .collect();
1322        let handles_ptr = if handles.is_empty() {
1323            ptr::null()
1324        } else {
1325            handles.as_ptr()
1326        };
1327        let ptr = unsafe {
1328            ffi::mps_rnn_image_inference_layer_new_stack(
1329                device.as_ptr(),
1330                descriptors.len(),
1331                handles_ptr,
1332            )
1333        };
1334        if ptr.is_null() {
1335            None
1336        } else {
1337            Some(Self { ptr })
1338        }
1339    }
1340
1341    #[must_use]
1342    pub fn input_feature_channels(&self) -> usize {
1343        unsafe { ffi::mps_rnn_image_inference_layer_input_feature_channels(self.ptr) }
1344    }
1345
1346    #[must_use]
1347    pub fn output_feature_channels(&self) -> usize {
1348        unsafe { ffi::mps_rnn_image_inference_layer_output_feature_channels(self.ptr) }
1349    }
1350
1351    #[must_use]
1352    pub fn number_of_layers(&self) -> usize {
1353        unsafe { ffi::mps_rnn_image_inference_layer_number_of_layers(self.ptr) }
1354    }
1355
1356    #[must_use]
1357    pub fn recurrent_output_is_temporary(&self) -> bool {
1358        unsafe { ffi::mps_rnn_image_inference_layer_recurrent_output_is_temporary(self.ptr) }
1359    }
1360
1361    pub fn set_recurrent_output_is_temporary(&self, value: bool) {
1362        unsafe {
1363            ffi::mps_rnn_image_inference_layer_set_recurrent_output_is_temporary(self.ptr, value);
1364        }
1365    }
1366
1367    #[must_use]
1368    pub fn store_all_intermediate_states(&self) -> bool {
1369        unsafe { ffi::mps_rnn_image_inference_layer_store_all_intermediate_states(self.ptr) }
1370    }
1371
1372    pub fn set_store_all_intermediate_states(&self, value: bool) {
1373        unsafe {
1374            ffi::mps_rnn_image_inference_layer_set_store_all_intermediate_states(self.ptr, value);
1375        }
1376    }
1377
1378    #[must_use]
1379    pub fn bidirectional_combine_mode(&self) -> usize {
1380        unsafe { ffi::mps_rnn_image_inference_layer_bidirectional_combine_mode(self.ptr) }
1381    }
1382
1383    pub fn set_bidirectional_combine_mode(&self, value: usize) {
1384        unsafe {
1385            ffi::mps_rnn_image_inference_layer_set_bidirectional_combine_mode(self.ptr, value);
1386        }
1387    }
1388
1389    #[must_use]
1390    pub fn encode_sequence(
1391        &self,
1392        command_buffer: &CommandBuffer,
1393        source_images: &[&Image],
1394        destination_images: &[&Image],
1395        recurrent_input_state: Option<&RnnRecurrentImageState>,
1396    ) -> Option<RnnRecurrentImageState> {
1397        if source_images.len() != destination_images.len() {
1398            return None;
1399        }
1400        let source_handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
1401        let destination_handles: Vec<_> = destination_images
1402            .iter()
1403            .map(|image| image.as_ptr())
1404            .collect();
1405        let source_ptr = if source_handles.is_empty() {
1406            ptr::null()
1407        } else {
1408            source_handles.as_ptr()
1409        };
1410        let destination_ptr = if destination_handles.is_empty() {
1411            ptr::null()
1412        } else {
1413            destination_handles.as_ptr()
1414        };
1415        let recurrent_input_ptr =
1416            recurrent_input_state.map_or(ptr::null_mut(), RnnRecurrentImageState::as_ptr);
1417        let ptr = unsafe {
1418            ffi::mps_rnn_image_inference_layer_encode_sequence(
1419                self.ptr,
1420                command_buffer.as_ptr(),
1421                source_images.len(),
1422                source_ptr,
1423                destination_ptr,
1424                recurrent_input_ptr,
1425            )
1426        };
1427        if ptr.is_null() {
1428            None
1429        } else {
1430            Some(RnnRecurrentImageState { ptr })
1431        }
1432    }
1433}