1use crate::ffi;
2use crate::image::Image;
3use crate::matrix::{Matrix, Vector};
4use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
5use core::ffi::c_void;
6use core::ptr;
7
/// Traversal order for an RNN layer's input sequence (mirrors MPS `MPSRNNSequenceDirection`).
pub mod rnn_sequence_direction {
    /// Walk the sequence from its first element to its last.
    pub const FORWARD: usize = 0;
    /// Walk the sequence from its last element to its first.
    pub const BACKWARD: usize = 1;
}
13
/// How the forward and backward passes of a bidirectional RNN are merged
/// (mirrors MPS `MPSRNNBidirectionalCombineMode`).
pub mod rnn_bidirectional_combine_mode {
    /// Keep the forward and backward outputs separate.
    pub const NONE: usize = 0;
    /// Sum the forward and backward outputs.
    pub const ADD: usize = 1;
    /// Concatenate the forward and backward outputs along the feature channels.
    pub const CONCATENATE: usize = 2;
}
20
/// Option flags accepted by convolution constructors (mirrors MPS `MPSCNNConvolutionFlags`).
pub mod cnn_convolution_flags {
    /// No special construction options.
    pub const NONE: usize = 0;
}
25
/// Memory layouts for convolution weights (mirrors MPS `MPSCNNConvolutionWeightsLayout`).
///
/// Unlike the sibling option modules, this constant is typed `u32` rather than `usize`.
pub mod cnn_convolution_weights_layout {
    /// Output-channel / height / width / input-channel ordering.
    pub const OHWI: u32 = 0;
}
30
/// Accumulator precision for convolution kernels
/// (mirrors MPS `MPSNNConvolutionAccumulatorPrecisionOption`).
pub mod cnn_accumulator_precision_option {
    /// Accumulate in 16-bit floating point.
    pub const HALF: usize = 0;
    /// Accumulate in 32-bit floating point.
    pub const FLOAT: usize = 1;
}
36
/// Weight-regularization kinds used by the optimizer descriptors
/// (mirrors MPS `MPSNNRegularizationType`).
pub mod nn_regularization_type {
    /// No regularization term.
    pub const NONE: usize = 0;
    /// L1 regularization.
    pub const L1: usize = 1;
    /// L2 regularization.
    pub const L2: usize = 2;
}
43
44pub use crate::generated::neural::*;
45
/// Declares an owning wrapper type around a raw MPS object pointer.
///
/// The generated type stores the pointer privately, exposes it via `as_ptr` for
/// passing back into FFI calls, and hands it to `ffi::mps_object_release` when
/// dropped (a null pointer is never released).
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            ptr: *mut c_void,
        }

        // NOTE(review): these impls promise thread-safety for the wrapped object. Many
        // wrappers expose `&self` setters, so `Sync` permits concurrent mutation from
        // several threads — confirm the underlying objects tolerate that.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl $name {
            /// Returns the wrapped object pointer without transferring ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }

        impl Drop for $name {
            fn drop(&mut self) {
                if self.ptr.is_null() {
                    return;
                }
                // SAFETY: the pointer is non-null and this wrapper owns one reference to it.
                unsafe { ffi::mps_object_release(self.ptr) };
                // Clear the field so the released pointer is never observed again.
                self.ptr = ptr::null_mut();
            }
        }
    };
}
75
/// Adds a `result_image` accessor to a graph filter-node wrapper.
macro_rules! impl_filter_result_image {
    ($name:ident) => {
        impl $name {
            /// Returns the image node produced by this filter node.
            ///
            /// Yields `None` when `mps_nn_filter_node_result_image` returns null.
            ///
            /// NOTE(review): the returned `NNImageNode` releases its pointer on drop,
            /// so this assumes the FFI shim hands back a retained (+1) reference —
            /// confirm, otherwise this over-releases.
            #[must_use]
            pub fn result_image(&self) -> Option<NNImageNode> {
                let raw = unsafe { ffi::mps_nn_filter_node_result_image(self.ptr) };
                if raw.is_null() {
                    return None;
                }
                Some(NNImageNode { ptr: raw })
            }
        }
    };
}
91
92fn retained_handle(ptr: *mut c_void) -> Option<*mut c_void> {
93 let retained = unsafe { ffi::mps_object_retain(ptr) };
94 if retained.is_null() {
95 None
96 } else {
97 Some(retained)
98 }
99}
100
/// Implements the property accessors shared by the generic RNN descriptor wrappers.
///
/// Every generated method forwards to the matching `ffi::mps_rnn_descriptor_*`
/// function with the wrapper's raw pointer.
macro_rules! impl_rnn_descriptor_common {
    ($name:ident) => {
        impl $name {
            /// Number of feature channels in the layer input.
            #[must_use]
            pub fn input_feature_channels(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_input_feature_channels(self.ptr) }
            }

            /// Sets the number of feature channels in the layer input.
            pub fn set_input_feature_channels(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_input_feature_channels(self.ptr, value) };
            }

            /// Number of feature channels in the layer output.
            #[must_use]
            pub fn output_feature_channels(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_output_feature_channels(self.ptr) }
            }

            /// Sets the number of feature channels in the layer output.
            pub fn set_output_feature_channels(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_output_feature_channels(self.ptr, value) };
            }

            /// Whether the layer-input unit-transform mode is enabled.
            #[must_use]
            pub fn use_layer_input_unit_transform_mode(&self) -> bool {
                unsafe { ffi::mps_rnn_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
            }

            /// Enables or disables the layer-input unit-transform mode.
            pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
                unsafe {
                    ffi::mps_rnn_descriptor_set_use_layer_input_unit_transform_mode(self.ptr, value)
                };
            }

            /// Whether weights are stored as 32-bit floats.
            #[must_use]
            pub fn use_float32_weights(&self) -> bool {
                unsafe { ffi::mps_rnn_descriptor_use_float32_weights(self.ptr) }
            }

            /// Chooses whether weights are stored as 32-bit floats.
            pub fn set_use_float32_weights(&self, value: bool) {
                unsafe { ffi::mps_rnn_descriptor_set_use_float32_weights(self.ptr, value) };
            }

            /// Sequence traversal direction (see `rnn_sequence_direction`).
            #[must_use]
            pub fn layer_sequence_direction(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_layer_sequence_direction(self.ptr) }
            }

            /// Sets the sequence traversal direction (see `rnn_sequence_direction`).
            pub fn set_layer_sequence_direction(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_layer_sequence_direction(self.ptr, value) };
            }
        }
    };
}
153
/// Implements the hyper-parameter accessors shared by every optimizer wrapper.
///
/// Each generated method forwards to the matching `ffi::mps_nn_optimizer_*`
/// function with the wrapper's raw pointer. Only the learning rate and the
/// gradient-clipping switch have setters here.
macro_rules! impl_optimizer_common {
    ($name:ident) => {
        impl $name {
            /// Current learning rate.
            #[must_use]
            pub fn learning_rate(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_learning_rate(self.ptr) }
            }

            /// Replaces the learning rate.
            pub fn set_learning_rate(&self, value: f32) {
                unsafe { ffi::mps_nn_optimizer_set_learning_rate(self.ptr, value) };
            }

            /// Factor applied to incoming gradients.
            #[must_use]
            pub fn gradient_rescale(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_rescale(self.ptr) }
            }

            /// Whether gradients are clipped to `[gradient_clip_min, gradient_clip_max]`.
            #[must_use]
            pub fn apply_gradient_clipping(&self) -> bool {
                unsafe { ffi::mps_nn_optimizer_apply_gradient_clipping(self.ptr) }
            }

            /// Turns gradient clipping on or off.
            pub fn set_apply_gradient_clipping(&self, value: bool) {
                unsafe { ffi::mps_nn_optimizer_set_apply_gradient_clipping(self.ptr, value) };
            }

            /// Upper bound used when gradient clipping is enabled.
            #[must_use]
            pub fn gradient_clip_max(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_clip_max(self.ptr) }
            }

            /// Lower bound used when gradient clipping is enabled.
            #[must_use]
            pub fn gradient_clip_min(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_clip_min(self.ptr) }
            }

            /// Scale of the regularization term.
            #[must_use]
            pub fn regularization_scale(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_regularization_scale(self.ptr) }
            }

            /// Regularization kind (see `nn_regularization_type`).
            #[must_use]
            pub fn regularization_type(&self) -> usize {
                unsafe { ffi::mps_nn_optimizer_regularization_type(self.ptr) }
            }
        }
    };
}
202
203opaque_handle!(NNImageNode);
204impl NNImageNode {
205 #[must_use]
206 pub fn new() -> Option<Self> {
207 let ptr = unsafe { ffi::mps_nn_image_node_new() };
208 if ptr.is_null() {
209 None
210 } else {
211 Some(Self { ptr })
212 }
213 }
214
215 #[must_use]
216 pub fn exported() -> Option<Self> {
217 let ptr = unsafe { ffi::mps_nn_image_node_exported() };
218 if ptr.is_null() {
219 None
220 } else {
221 Some(Self { ptr })
222 }
223 }
224
225 #[must_use]
226 pub fn format(&self) -> usize {
227 unsafe { ffi::mps_nn_image_node_format(self.ptr) }
228 }
229
230 pub fn set_format(&self, format: usize) {
231 unsafe { ffi::mps_nn_image_node_set_format(self.ptr, format) };
232 }
233
234 #[must_use]
235 pub fn export_from_graph(&self) -> bool {
236 unsafe { ffi::mps_nn_image_node_export_from_graph(self.ptr) }
237 }
238
239 pub fn set_export_from_graph(&self, export: bool) {
240 unsafe { ffi::mps_nn_image_node_set_export_from_graph(self.ptr, export) };
241 }
242
243 #[must_use]
244 pub fn synchronize_resource(&self) -> bool {
245 unsafe { ffi::mps_nn_image_node_synchronize_resource(self.ptr) }
246 }
247
248 pub fn set_synchronize_resource(&self, synchronize: bool) {
249 unsafe { ffi::mps_nn_image_node_set_synchronize_resource(self.ptr, synchronize) };
250 }
251
252 pub fn use_default_allocator(&self) {
253 unsafe { ffi::mps_nn_image_node_use_default_allocator(self.ptr) };
254 }
255}
256
257opaque_handle!(CnnNeuronReluNode);
258impl CnnNeuronReluNode {
259 #[must_use]
260 pub fn new(source: &NNImageNode, a: f32) -> Option<Self> {
261 let ptr = unsafe { ffi::mps_cnn_neuron_relu_node_new(source.as_ptr(), a) };
262 if ptr.is_null() {
263 None
264 } else {
265 Some(Self { ptr })
266 }
267 }
268}
269impl_filter_result_image!(CnnNeuronReluNode);
270
271opaque_handle!(CnnPoolingMaxNode);
272impl CnnPoolingMaxNode {
273 #[must_use]
274 pub fn new(source: &NNImageNode, filter_size: usize, stride: usize) -> Option<Self> {
275 let ptr =
276 unsafe { ffi::mps_cnn_pooling_max_node_new(source.as_ptr(), filter_size, stride) };
277 if ptr.is_null() {
278 None
279 } else {
280 Some(Self { ptr })
281 }
282 }
283}
284impl_filter_result_image!(CnnPoolingMaxNode);
285
286opaque_handle!(CnnSoftMaxNode);
287impl CnnSoftMaxNode {
288 #[must_use]
289 pub fn new(source: &NNImageNode) -> Option<Self> {
290 let ptr = unsafe { ffi::mps_cnn_softmax_node_new(source.as_ptr()) };
291 if ptr.is_null() {
292 None
293 } else {
294 Some(Self { ptr })
295 }
296 }
297}
298impl_filter_result_image!(CnnSoftMaxNode);
299
300opaque_handle!(CnnUpsamplingNearestNode);
301impl CnnUpsamplingNearestNode {
302 #[must_use]
303 pub fn new(source: &NNImageNode, scale_x: usize, scale_y: usize) -> Option<Self> {
304 let ptr =
305 unsafe { ffi::mps_cnn_upsampling_nearest_node_new(source.as_ptr(), scale_x, scale_y) };
306 if ptr.is_null() {
307 None
308 } else {
309 Some(Self { ptr })
310 }
311 }
312}
313impl_filter_result_image!(CnnUpsamplingNearestNode);
314
315opaque_handle!(NNGraph);
316impl NNGraph {
317 #[must_use]
318 pub fn new(
319 device: &MetalDevice,
320 result_image: &NNImageNode,
321 result_image_is_needed: bool,
322 ) -> Option<Self> {
323 let ptr = unsafe {
324 ffi::mps_nn_graph_new(
325 device.as_ptr(),
326 result_image.as_ptr(),
327 result_image_is_needed,
328 )
329 };
330 if ptr.is_null() {
331 None
332 } else {
333 Some(Self { ptr })
334 }
335 }
336
337 #[must_use]
338 pub fn source_image_count(&self) -> usize {
339 unsafe { ffi::mps_nn_graph_source_image_count(self.ptr) }
340 }
341
342 #[must_use]
343 pub fn format(&self) -> usize {
344 unsafe { ffi::mps_nn_graph_format(self.ptr) }
345 }
346
347 pub fn set_format(&self, format: usize) {
348 unsafe { ffi::mps_nn_graph_set_format(self.ptr, format) };
349 }
350
351 pub fn set_output_state_is_temporary(&self, temporary: bool) {
352 unsafe { ffi::mps_nn_graph_set_output_state_is_temporary(self.ptr, temporary) };
353 }
354
355 pub fn use_default_destination_image_allocator(&self) {
356 unsafe { ffi::mps_nn_graph_use_default_destination_image_allocator(self.ptr) };
357 }
358
359 pub fn reload_from_data_sources(&self) {
360 unsafe { ffi::mps_nn_graph_reload_from_data_sources(self.ptr) };
361 }
362
363 #[must_use]
364 pub fn encode(
365 &self,
366 command_buffer: &CommandBuffer,
367 source_images: &[&Image],
368 ) -> Option<Image> {
369 let handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
370 let source_handles = if handles.is_empty() {
371 ptr::null()
372 } else {
373 handles.as_ptr()
374 };
375 let ptr = unsafe {
376 ffi::mps_nn_graph_encode(
377 self.ptr,
378 command_buffer.as_ptr(),
379 source_images.len(),
380 source_handles,
381 )
382 };
383 if ptr.is_null() {
384 None
385 } else {
386 Some(unsafe { Image::from_raw(ptr) })
387 }
388 }
389}
390
391opaque_handle!(CnnConvolutionDescriptor);
392impl CnnConvolutionDescriptor {
393 #[must_use]
394 pub fn new(
395 kernel_width: usize,
396 kernel_height: usize,
397 input_feature_channels: usize,
398 output_feature_channels: usize,
399 ) -> Option<Self> {
400 let ptr = unsafe {
401 ffi::mps_cnn_convolution_descriptor_new(
402 kernel_width,
403 kernel_height,
404 input_feature_channels,
405 output_feature_channels,
406 )
407 };
408 if ptr.is_null() {
409 None
410 } else {
411 Some(Self { ptr })
412 }
413 }
414
415 #[must_use]
416 pub fn kernel_width(&self) -> usize {
417 unsafe { ffi::mps_cnn_convolution_descriptor_kernel_width(self.ptr) }
418 }
419
420 #[must_use]
421 pub fn kernel_height(&self) -> usize {
422 unsafe { ffi::mps_cnn_convolution_descriptor_kernel_height(self.ptr) }
423 }
424
425 #[must_use]
426 pub fn input_feature_channels(&self) -> usize {
427 unsafe { ffi::mps_cnn_convolution_descriptor_input_feature_channels(self.ptr) }
428 }
429
430 #[must_use]
431 pub fn output_feature_channels(&self) -> usize {
432 unsafe { ffi::mps_cnn_convolution_descriptor_output_feature_channels(self.ptr) }
433 }
434
435 #[must_use]
436 pub fn stride_in_pixels_x(&self) -> usize {
437 unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_x(self.ptr) }
438 }
439
440 pub fn set_stride_in_pixels_x(&self, value: usize) {
441 unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_x(self.ptr, value) };
442 }
443
444 #[must_use]
445 pub fn stride_in_pixels_y(&self) -> usize {
446 unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_y(self.ptr) }
447 }
448
449 pub fn set_stride_in_pixels_y(&self, value: usize) {
450 unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_y(self.ptr, value) };
451 }
452
453 #[must_use]
454 pub fn groups(&self) -> usize {
455 unsafe { ffi::mps_cnn_convolution_descriptor_groups(self.ptr) }
456 }
457
458 pub fn set_groups(&self, value: usize) {
459 unsafe { ffi::mps_cnn_convolution_descriptor_set_groups(self.ptr, value) };
460 }
461
462 #[must_use]
463 pub fn dilation_rate_x(&self) -> usize {
464 unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_x(self.ptr) }
465 }
466
467 pub fn set_dilation_rate_x(&self, value: usize) {
468 unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_x(self.ptr, value) };
469 }
470
471 #[must_use]
472 pub fn dilation_rate_y(&self) -> usize {
473 unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_y(self.ptr) }
474 }
475
476 pub fn set_dilation_rate_y(&self, value: usize) {
477 unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_y(self.ptr, value) };
478 }
479}
480
481opaque_handle!(RnnSingleGateDescriptor);
482impl RnnSingleGateDescriptor {
483 #[must_use]
484 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
485 let ptr = unsafe {
486 ffi::mps_rnn_single_gate_descriptor_new(input_feature_channels, output_feature_channels)
487 };
488 if ptr.is_null() {
489 None
490 } else {
491 Some(Self { ptr })
492 }
493 }
494
495 #[must_use]
496 pub fn input_feature_channels(&self) -> usize {
497 unsafe { ffi::mps_rnn_single_gate_descriptor_input_feature_channels(self.ptr) }
498 }
499
500 pub fn set_input_feature_channels(&self, value: usize) {
501 unsafe { ffi::mps_rnn_single_gate_descriptor_set_input_feature_channels(self.ptr, value) };
502 }
503
504 #[must_use]
505 pub fn output_feature_channels(&self) -> usize {
506 unsafe { ffi::mps_rnn_single_gate_descriptor_output_feature_channels(self.ptr) }
507 }
508
509 pub fn set_output_feature_channels(&self, value: usize) {
510 unsafe { ffi::mps_rnn_single_gate_descriptor_set_output_feature_channels(self.ptr, value) };
511 }
512
513 #[must_use]
514 pub fn use_layer_input_unit_transform_mode(&self) -> bool {
515 unsafe { ffi::mps_rnn_single_gate_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
516 }
517
518 pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
519 unsafe {
520 ffi::mps_rnn_single_gate_descriptor_set_use_layer_input_unit_transform_mode(
521 self.ptr, value,
522 );
523 };
524 }
525
526 #[must_use]
527 pub fn use_float32_weights(&self) -> bool {
528 unsafe { ffi::mps_rnn_single_gate_descriptor_use_float32_weights(self.ptr) }
529 }
530
531 pub fn set_use_float32_weights(&self, value: bool) {
532 unsafe { ffi::mps_rnn_single_gate_descriptor_set_use_float32_weights(self.ptr, value) };
533 }
534
535 #[must_use]
536 pub fn layer_sequence_direction(&self) -> usize {
537 unsafe { ffi::mps_rnn_single_gate_descriptor_layer_sequence_direction(self.ptr) }
538 }
539
540 pub fn set_layer_sequence_direction(&self, value: usize) {
541 unsafe {
542 ffi::mps_rnn_single_gate_descriptor_set_layer_sequence_direction(self.ptr, value);
543 };
544 }
545
546 #[must_use]
547 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
548 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
549 }
550}
551
552opaque_handle!(CnnConvolution);
553impl CnnConvolution {
554 #[must_use]
555 pub fn new(
556 device: &MetalDevice,
557 descriptor: &CnnConvolutionDescriptor,
558 kernel_weights: &[f32],
559 bias_terms: Option<&[f32]>,
560 flags: usize,
561 ) -> Option<Self> {
562 if kernel_weights.is_empty() {
563 return None;
564 }
565 let bias_terms_ptr = bias_terms.map_or(ptr::null(), <[f32]>::as_ptr);
566 let ptr = unsafe {
567 ffi::mps_cnn_convolution_new(
568 device.as_ptr(),
569 descriptor.as_ptr(),
570 kernel_weights.as_ptr(),
571 bias_terms_ptr,
572 flags,
573 )
574 };
575 if ptr.is_null() {
576 None
577 } else {
578 Some(Self { ptr })
579 }
580 }
581
582 #[must_use]
583 pub fn input_feature_channels(&self) -> usize {
584 unsafe { ffi::mps_cnn_convolution_input_feature_channels(self.ptr) }
585 }
586
587 #[must_use]
588 pub fn output_feature_channels(&self) -> usize {
589 unsafe { ffi::mps_cnn_convolution_output_feature_channels(self.ptr) }
590 }
591
592 #[must_use]
593 pub fn groups(&self) -> usize {
594 unsafe { ffi::mps_cnn_convolution_groups(self.ptr) }
595 }
596
597 #[must_use]
598 pub fn sub_pixel_scale_factor(&self) -> usize {
599 unsafe { ffi::mps_cnn_convolution_sub_pixel_scale_factor(self.ptr) }
600 }
601
602 #[must_use]
603 pub fn channel_multiplier(&self) -> usize {
604 unsafe { ffi::mps_cnn_convolution_channel_multiplier(self.ptr) }
605 }
606
607 #[must_use]
608 pub fn accumulator_precision_option(&self) -> usize {
609 unsafe { ffi::mps_cnn_convolution_accumulator_precision_option(self.ptr) }
610 }
611
612 pub fn set_accumulator_precision_option(&self, value: usize) {
613 unsafe { ffi::mps_cnn_convolution_set_accumulator_precision_option(self.ptr, value) };
614 }
615
616 pub fn encode_image(
617 &self,
618 command_buffer: &CommandBuffer,
619 source: &Image,
620 destination: &Image,
621 ) {
622 unsafe {
623 ffi::mps_cnn_convolution_encode_image(
624 self.ptr,
625 command_buffer.as_ptr(),
626 source.as_ptr(),
627 destination.as_ptr(),
628 );
629 };
630 }
631}
632
633opaque_handle!(CnnConvolutionWeightsAndBiasesState);
634impl CnnConvolutionWeightsAndBiasesState {
635 #[must_use]
636 pub fn new_with_buffers(weights: &MetalBuffer, biases: Option<&MetalBuffer>) -> Option<Self> {
637 let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
638 let ptr = unsafe {
639 ffi::mps_cnn_convolution_weights_and_biases_state_new(weights.as_ptr(), biases_ptr)
640 };
641 if ptr.is_null() {
642 None
643 } else {
644 Some(Self { ptr })
645 }
646 }
647
648 #[must_use]
649 pub fn new_with_offsets(
650 weights: &MetalBuffer,
651 weights_offset: usize,
652 biases: Option<&MetalBuffer>,
653 biases_offset: usize,
654 descriptor: &CnnConvolutionDescriptor,
655 ) -> Option<Self> {
656 let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
657 let ptr = unsafe {
658 ffi::mps_cnn_convolution_weights_and_biases_state_new_with_offsets(
659 weights.as_ptr(),
660 weights_offset,
661 biases_ptr,
662 biases_offset,
663 descriptor.as_ptr(),
664 )
665 };
666 if ptr.is_null() {
667 None
668 } else {
669 Some(Self { ptr })
670 }
671 }
672
673 #[must_use]
674 pub fn new_with_device(
675 device: &MetalDevice,
676 descriptor: &CnnConvolutionDescriptor,
677 ) -> Option<Self> {
678 let ptr = unsafe {
679 ffi::mps_cnn_convolution_weights_and_biases_state_new_with_device(
680 device.as_ptr(),
681 descriptor.as_ptr(),
682 )
683 };
684 if ptr.is_null() {
685 None
686 } else {
687 Some(Self { ptr })
688 }
689 }
690
691 #[must_use]
692 pub fn weights_offset(&self) -> usize {
693 unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_weights_offset(self.ptr) }
694 }
695
696 #[must_use]
697 pub fn biases_offset(&self) -> usize {
698 unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_biases_offset(self.ptr) }
699 }
700}
701
702opaque_handle!(NNOptimizerDescriptor);
703impl NNOptimizerDescriptor {
704 #[must_use]
705 pub fn new(
706 learning_rate: f32,
707 gradient_rescale: f32,
708 regularization_type: usize,
709 regularization_scale: f32,
710 ) -> Option<Self> {
711 let ptr = unsafe {
712 ffi::mps_nn_optimizer_descriptor_new(
713 learning_rate,
714 gradient_rescale,
715 regularization_type,
716 regularization_scale,
717 )
718 };
719 if ptr.is_null() {
720 None
721 } else {
722 Some(Self { ptr })
723 }
724 }
725
726 #[must_use]
727 pub fn with_gradient_clipping(
728 learning_rate: f32,
729 gradient_rescale: f32,
730 apply_gradient_clipping: bool,
731 gradient_clip_max: f32,
732 gradient_clip_min: f32,
733 regularization_type: usize,
734 regularization_scale: f32,
735 ) -> Option<Self> {
736 let ptr = unsafe {
737 ffi::mps_nn_optimizer_descriptor_new_with_gradient_clipping(
738 learning_rate,
739 gradient_rescale,
740 apply_gradient_clipping,
741 gradient_clip_max,
742 gradient_clip_min,
743 regularization_type,
744 regularization_scale,
745 )
746 };
747 if ptr.is_null() {
748 None
749 } else {
750 Some(Self { ptr })
751 }
752 }
753
754 #[must_use]
755 pub fn learning_rate(&self) -> f32 {
756 unsafe { ffi::mps_nn_optimizer_descriptor_learning_rate(self.ptr) }
757 }
758
759 pub fn set_learning_rate(&self, value: f32) {
760 unsafe { ffi::mps_nn_optimizer_descriptor_set_learning_rate(self.ptr, value) };
761 }
762
763 #[must_use]
764 pub fn gradient_rescale(&self) -> f32 {
765 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_rescale(self.ptr) }
766 }
767
768 pub fn set_gradient_rescale(&self, value: f32) {
769 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_rescale(self.ptr, value) };
770 }
771
772 #[must_use]
773 pub fn apply_gradient_clipping(&self) -> bool {
774 unsafe { ffi::mps_nn_optimizer_descriptor_apply_gradient_clipping(self.ptr) }
775 }
776
777 pub fn set_apply_gradient_clipping(&self, value: bool) {
778 unsafe { ffi::mps_nn_optimizer_descriptor_set_apply_gradient_clipping(self.ptr, value) };
779 }
780
781 #[must_use]
782 pub fn gradient_clip_max(&self) -> f32 {
783 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_max(self.ptr) }
784 }
785
786 pub fn set_gradient_clip_max(&self, value: f32) {
787 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_max(self.ptr, value) };
788 }
789
790 #[must_use]
791 pub fn gradient_clip_min(&self) -> f32 {
792 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_min(self.ptr) }
793 }
794
795 pub fn set_gradient_clip_min(&self, value: f32) {
796 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_min(self.ptr, value) };
797 }
798
799 #[must_use]
800 pub fn regularization_scale(&self) -> f32 {
801 unsafe { ffi::mps_nn_optimizer_descriptor_regularization_scale(self.ptr) }
802 }
803
804 pub fn set_regularization_scale(&self, value: f32) {
805 unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_scale(self.ptr, value) };
806 }
807
808 #[must_use]
809 pub fn regularization_type(&self) -> usize {
810 unsafe { ffi::mps_nn_optimizer_descriptor_regularization_type(self.ptr) }
811 }
812
813 pub fn set_regularization_type(&self, value: usize) {
814 unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_type(self.ptr, value) };
815 }
816}
817
818opaque_handle!(NNOptimizer);
819impl_optimizer_common!(NNOptimizer);
820
821opaque_handle!(NNOptimizerStochasticGradientDescent);
822impl_optimizer_common!(NNOptimizerStochasticGradientDescent);
823impl NNOptimizerStochasticGradientDescent {
824 #[must_use]
825 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
826 let ptr = unsafe { ffi::mps_nn_optimizer_sgd_new(device.as_ptr(), learning_rate) };
827 if ptr.is_null() {
828 None
829 } else {
830 Some(Self { ptr })
831 }
832 }
833
834 #[must_use]
835 pub fn new_with_options(
836 device: &MetalDevice,
837 momentum_scale: f32,
838 use_nesterov_momentum: bool,
839 optimizer_descriptor: &NNOptimizerDescriptor,
840 ) -> Option<Self> {
841 let ptr = unsafe {
842 ffi::mps_nn_optimizer_sgd_new_with_options(
843 device.as_ptr(),
844 momentum_scale,
845 use_nesterov_momentum,
846 optimizer_descriptor.as_ptr(),
847 )
848 };
849 if ptr.is_null() {
850 None
851 } else {
852 Some(Self { ptr })
853 }
854 }
855
856 #[must_use]
857 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
858 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
859 }
860
861 #[must_use]
862 pub fn momentum_scale(&self) -> f32 {
863 unsafe { ffi::mps_nn_optimizer_sgd_momentum_scale(self.ptr) }
864 }
865
866 #[must_use]
867 pub fn use_nesterov_momentum(&self) -> bool {
868 unsafe { ffi::mps_nn_optimizer_sgd_use_nesterov_momentum(self.ptr) }
869 }
870
871 pub fn encode_vector(
872 &self,
873 command_buffer: &CommandBuffer,
874 input_gradient_vector: &Vector,
875 input_values_vector: &Vector,
876 input_momentum_vector: Option<&Vector>,
877 result_values_vector: &Vector,
878 ) {
879 let input_momentum_ptr = input_momentum_vector.map_or(ptr::null_mut(), Vector::as_ptr);
880 unsafe {
881 ffi::mps_nn_optimizer_sgd_encode_vector(
882 self.ptr,
883 command_buffer.as_ptr(),
884 input_gradient_vector.as_ptr(),
885 input_values_vector.as_ptr(),
886 input_momentum_ptr,
887 result_values_vector.as_ptr(),
888 );
889 };
890 }
891
892 pub fn encode_matrix(
893 &self,
894 command_buffer: &CommandBuffer,
895 input_gradient_matrix: &Matrix,
896 input_values_matrix: &Matrix,
897 input_momentum_matrix: Option<&Matrix>,
898 result_values_matrix: &Matrix,
899 ) {
900 let input_momentum_ptr = input_momentum_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
901 unsafe {
902 ffi::mps_nn_optimizer_sgd_encode_matrix(
903 self.ptr,
904 command_buffer.as_ptr(),
905 input_gradient_matrix.as_ptr(),
906 input_values_matrix.as_ptr(),
907 input_momentum_ptr,
908 result_values_matrix.as_ptr(),
909 );
910 };
911 }
912}
913
914opaque_handle!(NNOptimizerRmsProp);
915impl_optimizer_common!(NNOptimizerRmsProp);
916impl NNOptimizerRmsProp {
917 #[must_use]
918 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
919 let ptr = unsafe { ffi::mps_nn_optimizer_rmsprop_new(device.as_ptr(), learning_rate) };
920 if ptr.is_null() {
921 None
922 } else {
923 Some(Self { ptr })
924 }
925 }
926
927 #[must_use]
928 pub fn new_with_options(
929 device: &MetalDevice,
930 decay: f64,
931 epsilon: f32,
932 optimizer_descriptor: &NNOptimizerDescriptor,
933 ) -> Option<Self> {
934 let ptr = unsafe {
935 ffi::mps_nn_optimizer_rmsprop_new_with_options(
936 device.as_ptr(),
937 decay,
938 epsilon,
939 optimizer_descriptor.as_ptr(),
940 )
941 };
942 if ptr.is_null() {
943 None
944 } else {
945 Some(Self { ptr })
946 }
947 }
948
949 #[must_use]
950 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
951 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
952 }
953
954 #[must_use]
955 pub fn decay(&self) -> f64 {
956 unsafe { ffi::mps_nn_optimizer_rmsprop_decay(self.ptr) }
957 }
958
959 #[must_use]
960 pub fn epsilon(&self) -> f32 {
961 unsafe { ffi::mps_nn_optimizer_rmsprop_epsilon(self.ptr) }
962 }
963
964 pub fn encode_vector(
965 &self,
966 command_buffer: &CommandBuffer,
967 input_gradient_vector: &Vector,
968 input_values_vector: &Vector,
969 input_sum_of_squares_vector: &Vector,
970 result_values_vector: &Vector,
971 ) {
972 unsafe {
973 ffi::mps_nn_optimizer_rmsprop_encode_vector(
974 self.ptr,
975 command_buffer.as_ptr(),
976 input_gradient_vector.as_ptr(),
977 input_values_vector.as_ptr(),
978 input_sum_of_squares_vector.as_ptr(),
979 result_values_vector.as_ptr(),
980 );
981 };
982 }
983
984 pub fn encode_matrix(
985 &self,
986 command_buffer: &CommandBuffer,
987 input_gradient_matrix: &Matrix,
988 input_values_matrix: &Matrix,
989 input_sum_of_squares_matrix: &Matrix,
990 result_values_matrix: &Matrix,
991 ) {
992 unsafe {
993 ffi::mps_nn_optimizer_rmsprop_encode_matrix(
994 self.ptr,
995 command_buffer.as_ptr(),
996 input_gradient_matrix.as_ptr(),
997 input_values_matrix.as_ptr(),
998 input_sum_of_squares_matrix.as_ptr(),
999 result_values_matrix.as_ptr(),
1000 );
1001 };
1002 }
1003}
1004
1005opaque_handle!(NNOptimizerAdam);
1006impl_optimizer_common!(NNOptimizerAdam);
1007impl NNOptimizerAdam {
1008 #[must_use]
1009 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
1010 let ptr = unsafe { ffi::mps_nn_optimizer_adam_new(device.as_ptr(), learning_rate) };
1011 if ptr.is_null() {
1012 None
1013 } else {
1014 Some(Self { ptr })
1015 }
1016 }
1017
1018 #[must_use]
1019 pub fn new_with_options(
1020 device: &MetalDevice,
1021 beta1: f64,
1022 beta2: f64,
1023 epsilon: f32,
1024 time_step: usize,
1025 optimizer_descriptor: &NNOptimizerDescriptor,
1026 ) -> Option<Self> {
1027 let ptr = unsafe {
1028 ffi::mps_nn_optimizer_adam_new_with_options(
1029 device.as_ptr(),
1030 beta1,
1031 beta2,
1032 epsilon,
1033 time_step,
1034 optimizer_descriptor.as_ptr(),
1035 )
1036 };
1037 if ptr.is_null() {
1038 None
1039 } else {
1040 Some(Self { ptr })
1041 }
1042 }
1043
1044 #[must_use]
1045 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
1046 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
1047 }
1048
1049 #[must_use]
1050 pub fn beta1(&self) -> f64 {
1051 unsafe { ffi::mps_nn_optimizer_adam_beta1(self.ptr) }
1052 }
1053
1054 #[must_use]
1055 pub fn beta2(&self) -> f64 {
1056 unsafe { ffi::mps_nn_optimizer_adam_beta2(self.ptr) }
1057 }
1058
1059 #[must_use]
1060 pub fn epsilon(&self) -> f32 {
1061 unsafe { ffi::mps_nn_optimizer_adam_epsilon(self.ptr) }
1062 }
1063
1064 #[must_use]
1065 pub fn time_step(&self) -> usize {
1066 unsafe { ffi::mps_nn_optimizer_adam_time_step(self.ptr) }
1067 }
1068
1069 pub fn set_time_step(&self, value: usize) {
1070 unsafe { ffi::mps_nn_optimizer_adam_set_time_step(self.ptr, value) };
1071 }
1072
1073 pub fn encode_vector(
1074 &self,
1075 command_buffer: &CommandBuffer,
1076 input_gradient_vector: &Vector,
1077 input_values_vector: &Vector,
1078 input_momentum_vector: &Vector,
1079 input_velocity_vector: &Vector,
1080 result_values_vector: &Vector,
1081 ) {
1082 unsafe {
1083 ffi::mps_nn_optimizer_adam_encode_vector(
1084 self.ptr,
1085 command_buffer.as_ptr(),
1086 input_gradient_vector.as_ptr(),
1087 input_values_vector.as_ptr(),
1088 input_momentum_vector.as_ptr(),
1089 input_velocity_vector.as_ptr(),
1090 result_values_vector.as_ptr(),
1091 );
1092 };
1093 }
1094
1095 pub fn encode_matrix(
1096 &self,
1097 command_buffer: &CommandBuffer,
1098 input_gradient_matrix: &Matrix,
1099 input_values_matrix: &Matrix,
1100 input_momentum_matrix: &Matrix,
1101 input_velocity_matrix: &Matrix,
1102 result_values_matrix: &Matrix,
1103 ) {
1104 unsafe {
1105 ffi::mps_nn_optimizer_adam_encode_matrix(
1106 self.ptr,
1107 command_buffer.as_ptr(),
1108 input_gradient_matrix.as_ptr(),
1109 input_values_matrix.as_ptr(),
1110 input_momentum_matrix.as_ptr(),
1111 input_velocity_matrix.as_ptr(),
1112 result_values_matrix.as_ptr(),
1113 );
1114 };
1115 }
1116
1117 #[allow(clippy::too_many_arguments)]
1118 pub fn encode_amsgrad_vector(
1119 &self,
1120 command_buffer: &CommandBuffer,
1121 input_gradient_vector: &Vector,
1122 input_values_vector: &Vector,
1123 input_momentum_vector: &Vector,
1124 input_velocity_vector: &Vector,
1125 maximum_velocity_vector: Option<&Vector>,
1126 result_values_vector: &Vector,
1127 ) {
1128 let maximum_velocity_ptr = maximum_velocity_vector.map_or(ptr::null_mut(), Vector::as_ptr);
1129 unsafe {
1130 ffi::mps_nn_optimizer_adam_encode_amsgrad_vector(
1131 self.ptr,
1132 command_buffer.as_ptr(),
1133 input_gradient_vector.as_ptr(),
1134 input_values_vector.as_ptr(),
1135 input_momentum_vector.as_ptr(),
1136 input_velocity_vector.as_ptr(),
1137 maximum_velocity_ptr,
1138 result_values_vector.as_ptr(),
1139 );
1140 };
1141 }
1142
1143 #[allow(clippy::too_many_arguments)]
1144 pub fn encode_amsgrad_matrix(
1145 &self,
1146 command_buffer: &CommandBuffer,
1147 input_gradient_matrix: &Matrix,
1148 input_values_matrix: &Matrix,
1149 input_momentum_matrix: &Matrix,
1150 input_velocity_matrix: &Matrix,
1151 maximum_velocity_matrix: Option<&Matrix>,
1152 result_values_matrix: &Matrix,
1153 ) {
1154 let maximum_velocity_ptr = maximum_velocity_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
1155 unsafe {
1156 ffi::mps_nn_optimizer_adam_encode_amsgrad_matrix(
1157 self.ptr,
1158 command_buffer.as_ptr(),
1159 input_gradient_matrix.as_ptr(),
1160 input_values_matrix.as_ptr(),
1161 input_momentum_matrix.as_ptr(),
1162 input_velocity_matrix.as_ptr(),
1163 maximum_velocity_ptr,
1164 result_values_matrix.as_ptr(),
1165 );
1166 };
1167 }
1168}
1169
1170opaque_handle!(RnnDescriptor);
1171impl_rnn_descriptor_common!(RnnDescriptor);
1172
1173opaque_handle!(GruDescriptor);
1174impl_rnn_descriptor_common!(GruDescriptor);
1175impl GruDescriptor {
1176 #[must_use]
1177 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1178 let ptr =
1179 unsafe { ffi::mps_gru_descriptor_new(input_feature_channels, output_feature_channels) };
1180 if ptr.is_null() {
1181 None
1182 } else {
1183 Some(Self { ptr })
1184 }
1185 }
1186
1187 #[must_use]
1188 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1189 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1190 }
1191
1192 #[must_use]
1193 pub fn gate_pnorm_value(&self) -> f32 {
1194 unsafe { ffi::mps_gru_descriptor_gate_pnorm_value(self.ptr) }
1195 }
1196
1197 pub fn set_gate_pnorm_value(&self, value: f32) {
1198 unsafe { ffi::mps_gru_descriptor_set_gate_pnorm_value(self.ptr, value) };
1199 }
1200
1201 #[must_use]
1202 pub fn flip_output_gates(&self) -> bool {
1203 unsafe { ffi::mps_gru_descriptor_flip_output_gates(self.ptr) }
1204 }
1205
1206 pub fn set_flip_output_gates(&self, value: bool) {
1207 unsafe { ffi::mps_gru_descriptor_set_flip_output_gates(self.ptr, value) };
1208 }
1209}
1210
1211opaque_handle!(LstmDescriptor);
1212impl_rnn_descriptor_common!(LstmDescriptor);
1213impl LstmDescriptor {
1214 #[must_use]
1215 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1216 let ptr = unsafe {
1217 ffi::mps_lstm_descriptor_new(input_feature_channels, output_feature_channels)
1218 };
1219 if ptr.is_null() {
1220 None
1221 } else {
1222 Some(Self { ptr })
1223 }
1224 }
1225
1226 #[must_use]
1227 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1228 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1229 }
1230
1231 #[must_use]
1232 pub fn memory_weights_are_diagonal(&self) -> bool {
1233 unsafe { ffi::mps_lstm_descriptor_memory_weights_are_diagonal(self.ptr) }
1234 }
1235
1236 pub fn set_memory_weights_are_diagonal(&self, value: bool) {
1237 unsafe { ffi::mps_lstm_descriptor_set_memory_weights_are_diagonal(self.ptr, value) };
1238 }
1239
1240 #[must_use]
1241 pub fn cell_to_output_neuron_type(&self) -> usize {
1242 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_type(self.ptr) }
1243 }
1244
1245 pub fn set_cell_to_output_neuron_type(&self, value: usize) {
1246 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_type(self.ptr, value) };
1247 }
1248
1249 #[must_use]
1250 pub fn cell_to_output_neuron_param_a(&self) -> f32 {
1251 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_a(self.ptr) }
1252 }
1253
1254 pub fn set_cell_to_output_neuron_param_a(&self, value: f32) {
1255 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_a(self.ptr, value) };
1256 }
1257
1258 #[must_use]
1259 pub fn cell_to_output_neuron_param_b(&self) -> f32 {
1260 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_b(self.ptr) }
1261 }
1262
1263 pub fn set_cell_to_output_neuron_param_b(&self, value: f32) {
1264 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_b(self.ptr, value) };
1265 }
1266
1267 #[must_use]
1268 pub fn cell_to_output_neuron_param_c(&self) -> f32 {
1269 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_c(self.ptr) }
1270 }
1271
1272 pub fn set_cell_to_output_neuron_param_c(&self, value: f32) {
1273 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_c(self.ptr, value) };
1274 }
1275}
1276
1277opaque_handle!(RnnRecurrentImageState);
1278impl RnnRecurrentImageState {
1279 #[must_use]
1280 pub fn recurrent_output_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1281 let ptr = unsafe {
1282 ffi::mps_rnn_recurrent_image_state_recurrent_output_image(self.ptr, layer_index)
1283 };
1284 if ptr.is_null() {
1285 None
1286 } else {
1287 Some(unsafe { Image::from_raw(ptr) })
1288 }
1289 }
1290
1291 #[must_use]
1292 pub fn memory_cell_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1293 let ptr =
1294 unsafe { ffi::mps_rnn_recurrent_image_state_memory_cell_image(self.ptr, layer_index) };
1295 if ptr.is_null() {
1296 None
1297 } else {
1298 Some(unsafe { Image::from_raw(ptr) })
1299 }
1300 }
1301}
1302
1303opaque_handle!(RnnImageInferenceLayer);
1304impl RnnImageInferenceLayer {
1305 #[must_use]
1306 pub fn new(device: &MetalDevice, descriptor: &RnnDescriptor) -> Option<Self> {
1307 let ptr =
1308 unsafe { ffi::mps_rnn_image_inference_layer_new(device.as_ptr(), descriptor.as_ptr()) };
1309 if ptr.is_null() {
1310 None
1311 } else {
1312 Some(Self { ptr })
1313 }
1314 }
1315
1316 #[must_use]
1317 pub fn new_stack(device: &MetalDevice, descriptors: &[&RnnDescriptor]) -> Option<Self> {
1318 let handles: Vec<_> = descriptors
1319 .iter()
1320 .map(|descriptor| descriptor.as_ptr())
1321 .collect();
1322 let handles_ptr = if handles.is_empty() {
1323 ptr::null()
1324 } else {
1325 handles.as_ptr()
1326 };
1327 let ptr = unsafe {
1328 ffi::mps_rnn_image_inference_layer_new_stack(
1329 device.as_ptr(),
1330 descriptors.len(),
1331 handles_ptr,
1332 )
1333 };
1334 if ptr.is_null() {
1335 None
1336 } else {
1337 Some(Self { ptr })
1338 }
1339 }
1340
1341 #[must_use]
1342 pub fn input_feature_channels(&self) -> usize {
1343 unsafe { ffi::mps_rnn_image_inference_layer_input_feature_channels(self.ptr) }
1344 }
1345
1346 #[must_use]
1347 pub fn output_feature_channels(&self) -> usize {
1348 unsafe { ffi::mps_rnn_image_inference_layer_output_feature_channels(self.ptr) }
1349 }
1350
1351 #[must_use]
1352 pub fn number_of_layers(&self) -> usize {
1353 unsafe { ffi::mps_rnn_image_inference_layer_number_of_layers(self.ptr) }
1354 }
1355
1356 #[must_use]
1357 pub fn recurrent_output_is_temporary(&self) -> bool {
1358 unsafe { ffi::mps_rnn_image_inference_layer_recurrent_output_is_temporary(self.ptr) }
1359 }
1360
1361 pub fn set_recurrent_output_is_temporary(&self, value: bool) {
1362 unsafe {
1363 ffi::mps_rnn_image_inference_layer_set_recurrent_output_is_temporary(self.ptr, value);
1364 }
1365 }
1366
1367 #[must_use]
1368 pub fn store_all_intermediate_states(&self) -> bool {
1369 unsafe { ffi::mps_rnn_image_inference_layer_store_all_intermediate_states(self.ptr) }
1370 }
1371
1372 pub fn set_store_all_intermediate_states(&self, value: bool) {
1373 unsafe {
1374 ffi::mps_rnn_image_inference_layer_set_store_all_intermediate_states(self.ptr, value);
1375 }
1376 }
1377
1378 #[must_use]
1379 pub fn bidirectional_combine_mode(&self) -> usize {
1380 unsafe { ffi::mps_rnn_image_inference_layer_bidirectional_combine_mode(self.ptr) }
1381 }
1382
1383 pub fn set_bidirectional_combine_mode(&self, value: usize) {
1384 unsafe {
1385 ffi::mps_rnn_image_inference_layer_set_bidirectional_combine_mode(self.ptr, value);
1386 }
1387 }
1388
1389 #[must_use]
1390 pub fn encode_sequence(
1391 &self,
1392 command_buffer: &CommandBuffer,
1393 source_images: &[&Image],
1394 destination_images: &[&Image],
1395 recurrent_input_state: Option<&RnnRecurrentImageState>,
1396 ) -> Option<RnnRecurrentImageState> {
1397 if source_images.len() != destination_images.len() {
1398 return None;
1399 }
1400 let source_handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
1401 let destination_handles: Vec<_> = destination_images
1402 .iter()
1403 .map(|image| image.as_ptr())
1404 .collect();
1405 let source_ptr = if source_handles.is_empty() {
1406 ptr::null()
1407 } else {
1408 source_handles.as_ptr()
1409 };
1410 let destination_ptr = if destination_handles.is_empty() {
1411 ptr::null()
1412 } else {
1413 destination_handles.as_ptr()
1414 };
1415 let recurrent_input_ptr =
1416 recurrent_input_state.map_or(ptr::null_mut(), RnnRecurrentImageState::as_ptr);
1417 let ptr = unsafe {
1418 ffi::mps_rnn_image_inference_layer_encode_sequence(
1419 self.ptr,
1420 command_buffer.as_ptr(),
1421 source_images.len(),
1422 source_ptr,
1423 destination_ptr,
1424 recurrent_input_ptr,
1425 )
1426 };
1427 if ptr.is_null() {
1428 None
1429 } else {
1430 Some(RnnRecurrentImageState { ptr })
1431 }
1432 }
1433}