1use crate::ffi;
2use crate::image::Image;
3use crate::matrix::{Matrix, Vector};
4use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
5use core::ffi::c_void;
6use core::ptr;
7
/// Raw `MPSRNNSequenceDirection` values, intended for the `set_layer_sequence_direction`
/// setters on the RNN descriptor wrappers.
pub mod rnn_sequence_direction {
    /// Walk the input sequence from its first element to its last.
    pub const FORWARD: usize = 0;
    /// Walk the input sequence from its last element to its first.
    pub const BACKWARD: usize = 1;
}
15
/// Raw `MPSRNNBidirectionalCombineMode` values describing how the forward and
/// backward passes of a bidirectional RNN are merged.
pub mod rnn_bidirectional_combine_mode {
    /// Keep the two directional outputs separate.
    pub const NONE: usize = 0;
    /// Sum the two directional outputs.
    pub const ADD: usize = 1;
    /// Concatenate the two directional outputs.
    pub const CONCATENATE: usize = 2;
}
25
/// Raw `MPSCNNConvolutionFlags` values, passed as the `flags` argument of
/// `CnnConvolution::new`.
pub mod cnn_convolution_flags {
    /// No special convolution behaviour requested.
    pub const NONE: usize = 0;
}
31
/// Raw `MPSCNNConvolutionWeightsLayout` values (a 32-bit enum on the Objective-C side).
pub mod cnn_convolution_weights_layout {
    /// Weights ordered as `[output][height][width][input]`.
    pub const OHWI: u32 = 0;
}
37
/// Raw accumulator-precision values, intended for
/// `CnnConvolution::set_accumulator_precision_option`.
pub mod cnn_accumulator_precision_option {
    /// Accumulate in half precision.
    pub const HALF: usize = 0;
    /// Accumulate in single precision.
    pub const FLOAT: usize = 1;
}
45
/// Raw `MPSNNRegularizationType` values, used as the `regularization_type` of
/// `NNOptimizerDescriptor`.
pub mod nn_regularization_type {
    /// No regularization term.
    pub const NONE: usize = 0;
    /// L1 (absolute-value) regularization.
    pub const L1: usize = 1;
    /// L2 (squared) regularization.
    pub const L2: usize = 2;
}
55
56#[doc(hidden)]
57pub use crate::generated::neural::*;
58
/// Declares a thin owning wrapper around an opaque object pointer.
///
/// The generated type stores the raw handle in a private `ptr` field, exposes it
/// through `as_ptr`, and hands it to `ffi::mps_object_release` once on drop
/// (a null handle is never released).
macro_rules! opaque_handle {
    ($name:ident, $doc:expr) => {
        #[doc = $doc]
        pub struct $name {
            ptr: *mut c_void,
        }

        // SAFETY: NOTE(review) these impls assert the wrapped object may be moved to
        // and shared across threads. The `&self` setters generated elsewhere are not
        // synchronised here, so this relies on MPS-side guarantees — confirm.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                // Clear the field first so the wrapper never observes a released handle.
                let handle = core::mem::replace(&mut self.ptr, ptr::null_mut());
                if handle.is_null() {
                    return;
                }
                // SAFETY: `handle` is the non-null object owned by this wrapper and, with
                // `self.ptr` already cleared, it is released exactly once.
                unsafe { ffi::mps_object_release(handle) };
            }
        }

        impl $name {
            /// Returns the raw object pointer; ownership stays with `self`.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
90
/// Adds a `result_image` accessor to a filter-node wrapper declared with `opaque_handle!`.
macro_rules! impl_filter_result_image {
    ($name:ident) => {
        impl $name {
            /// Returns the image node produced by this filter node, or `None` when the
            /// FFI call yields a null pointer.
            #[must_use]
            pub fn result_image(&self) -> Option<NNImageNode> {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                let raw = unsafe { ffi::mps_nn_filter_node_result_image(self.ptr) };
                ptr::NonNull::new(raw).map(|node| NNImageNode { ptr: node.as_ptr() })
            }
        }
    };
}
107
108fn retained_handle(ptr: *mut c_void) -> Option<*mut c_void> {
109 let retained = unsafe { ffi::mps_object_retain(ptr) };
110 if retained.is_null() {
111 None
112 } else {
113 Some(retained)
114 }
115}
116
/// Implements the `MPSRNNDescriptor` property accessors shared by every RNN descriptor
/// wrapper. The target type must come from `opaque_handle!` (its `ptr` field is read).
macro_rules! impl_rnn_descriptor_common {
    ($name:ident) => {
        impl $name {
            /// Returns the number of input feature channels.
            #[must_use]
            pub fn input_feature_channels(&self) -> usize {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_input_feature_channels(self.ptr) }
            }

            /// Sets the number of input feature channels.
            pub fn set_input_feature_channels(&self, value: usize) {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe {
                    ffi::mps_rnn_descriptor_set_input_feature_channels(self.ptr, value);
                }
            }

            /// Returns the number of output feature channels.
            #[must_use]
            pub fn output_feature_channels(&self) -> usize {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_output_feature_channels(self.ptr) }
            }

            /// Sets the number of output feature channels.
            pub fn set_output_feature_channels(&self, value: usize) {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe {
                    ffi::mps_rnn_descriptor_set_output_feature_channels(self.ptr, value);
                }
            }

            /// Returns whether the layer-input unit-transform mode is enabled.
            #[must_use]
            pub fn use_layer_input_unit_transform_mode(&self) -> bool {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
            }

            /// Enables or disables the layer-input unit-transform mode.
            pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe {
                    ffi::mps_rnn_descriptor_set_use_layer_input_unit_transform_mode(
                        self.ptr, value,
                    );
                }
            }

            /// Returns whether weights are stored as 32-bit floats.
            #[must_use]
            pub fn use_float32_weights(&self) -> bool {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_use_float32_weights(self.ptr) }
            }

            /// Chooses whether weights are stored as 32-bit floats.
            pub fn set_use_float32_weights(&self, value: bool) {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe {
                    ffi::mps_rnn_descriptor_set_use_float32_weights(self.ptr, value);
                }
            }

            /// Returns the raw sequence direction (see `rnn_sequence_direction`).
            #[must_use]
            pub fn layer_sequence_direction(&self) -> usize {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_rnn_descriptor_layer_sequence_direction(self.ptr) }
            }

            /// Sets the raw sequence direction (see `rnn_sequence_direction`).
            pub fn set_layer_sequence_direction(&self, value: usize) {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe {
                    ffi::mps_rnn_descriptor_set_layer_sequence_direction(self.ptr, value);
                }
            }
        }
    };
}
179
/// Implements the `MPSNNOptimizer` accessors shared by every optimizer wrapper.
///
/// Only the learning rate and the gradient-clipping switch get setters; the other
/// properties are exposed read-only. The target type must come from `opaque_handle!`.
macro_rules! impl_optimizer_common {
    ($name:ident) => {
        impl $name {
            /// Returns the current learning rate.
            #[must_use]
            pub fn learning_rate(&self) -> f32 {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_learning_rate(self.ptr) }
            }

            /// Replaces the learning rate.
            pub fn set_learning_rate(&self, value: f32) {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe {
                    ffi::mps_nn_optimizer_set_learning_rate(self.ptr, value);
                }
            }

            /// Returns the gradient rescale factor.
            #[must_use]
            pub fn gradient_rescale(&self) -> f32 {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_gradient_rescale(self.ptr) }
            }

            /// Returns whether gradient clipping is applied.
            #[must_use]
            pub fn apply_gradient_clipping(&self) -> bool {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_apply_gradient_clipping(self.ptr) }
            }

            /// Turns gradient clipping on or off.
            pub fn set_apply_gradient_clipping(&self, value: bool) {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe {
                    ffi::mps_nn_optimizer_set_apply_gradient_clipping(self.ptr, value);
                }
            }

            /// Returns the upper gradient-clipping bound.
            #[must_use]
            pub fn gradient_clip_max(&self) -> f32 {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_gradient_clip_max(self.ptr) }
            }

            /// Returns the lower gradient-clipping bound.
            #[must_use]
            pub fn gradient_clip_min(&self) -> f32 {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_gradient_clip_min(self.ptr) }
            }

            /// Returns the regularization scale.
            #[must_use]
            pub fn regularization_scale(&self) -> f32 {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_regularization_scale(self.ptr) }
            }

            /// Returns the raw regularization type (see `nn_regularization_type`).
            #[must_use]
            pub fn regularization_type(&self) -> usize {
                // SAFETY: `self.ptr` is the live handle owned by this wrapper.
                unsafe { ffi::mps_nn_optimizer_regularization_type(self.ptr) }
            }
        }
    };
}
237
238opaque_handle!(NNImageNode, "Wraps `MPSNNImageNode`.");
239impl NNImageNode {
240 #[must_use]
242 pub fn new() -> Option<Self> {
243 let ptr = unsafe { ffi::mps_nn_image_node_new() };
244 if ptr.is_null() {
245 None
246 } else {
247 Some(Self { ptr })
248 }
249 }
250
251 #[must_use]
253 pub fn exported() -> Option<Self> {
254 let ptr = unsafe { ffi::mps_nn_image_node_exported() };
255 if ptr.is_null() {
256 None
257 } else {
258 Some(Self { ptr })
259 }
260 }
261
262 #[must_use]
264 pub fn format(&self) -> usize {
265 unsafe { ffi::mps_nn_image_node_format(self.ptr) }
266 }
267
268 pub fn set_format(&self, format: usize) {
270 unsafe { ffi::mps_nn_image_node_set_format(self.ptr, format) };
271 }
272
273 #[must_use]
275 pub fn export_from_graph(&self) -> bool {
276 unsafe { ffi::mps_nn_image_node_export_from_graph(self.ptr) }
277 }
278
279 pub fn set_export_from_graph(&self, export: bool) {
281 unsafe { ffi::mps_nn_image_node_set_export_from_graph(self.ptr, export) };
282 }
283
284 #[must_use]
286 pub fn synchronize_resource(&self) -> bool {
287 unsafe { ffi::mps_nn_image_node_synchronize_resource(self.ptr) }
288 }
289
290 pub fn set_synchronize_resource(&self, synchronize: bool) {
292 unsafe { ffi::mps_nn_image_node_set_synchronize_resource(self.ptr, synchronize) };
293 }
294
295 pub fn use_default_allocator(&self) {
297 unsafe { ffi::mps_nn_image_node_use_default_allocator(self.ptr) };
298 }
299}
300
301opaque_handle!(CnnNeuronReluNode, "Wraps `MPSCNNNeuronReLUNode`.");
302impl CnnNeuronReluNode {
303 #[must_use]
305 pub fn new(source: &NNImageNode, a: f32) -> Option<Self> {
306 let ptr = unsafe { ffi::mps_cnn_neuron_relu_node_new(source.as_ptr(), a) };
307 if ptr.is_null() {
308 None
309 } else {
310 Some(Self { ptr })
311 }
312 }
313}
314impl_filter_result_image!(CnnNeuronReluNode);
315
316opaque_handle!(CnnPoolingMaxNode, "Wraps `MPSCNNPoolingMaxNode`.");
317impl CnnPoolingMaxNode {
318 #[must_use]
320 pub fn new(source: &NNImageNode, filter_size: usize, stride: usize) -> Option<Self> {
321 let ptr =
322 unsafe { ffi::mps_cnn_pooling_max_node_new(source.as_ptr(), filter_size, stride) };
323 if ptr.is_null() {
324 None
325 } else {
326 Some(Self { ptr })
327 }
328 }
329}
330impl_filter_result_image!(CnnPoolingMaxNode);
331
332opaque_handle!(CnnSoftMaxNode, "Wraps `MPSCNNSoftMaxNode`.");
333impl CnnSoftMaxNode {
334 #[must_use]
336 pub fn new(source: &NNImageNode) -> Option<Self> {
337 let ptr = unsafe { ffi::mps_cnn_softmax_node_new(source.as_ptr()) };
338 if ptr.is_null() {
339 None
340 } else {
341 Some(Self { ptr })
342 }
343 }
344}
345impl_filter_result_image!(CnnSoftMaxNode);
346
347opaque_handle!(CnnUpsamplingNearestNode, "Wraps `MPSCNNUpsamplingNearestNode`.");
348impl CnnUpsamplingNearestNode {
349 #[must_use]
351 pub fn new(source: &NNImageNode, scale_x: usize, scale_y: usize) -> Option<Self> {
352 let ptr =
353 unsafe { ffi::mps_cnn_upsampling_nearest_node_new(source.as_ptr(), scale_x, scale_y) };
354 if ptr.is_null() {
355 None
356 } else {
357 Some(Self { ptr })
358 }
359 }
360}
361impl_filter_result_image!(CnnUpsamplingNearestNode);
362
363opaque_handle!(NNGraph, "Wraps `MPSNNGraph`.");
364impl NNGraph {
365 #[must_use]
367 pub fn new(
368 device: &MetalDevice,
369 result_image: &NNImageNode,
370 result_image_is_needed: bool,
371 ) -> Option<Self> {
372 let ptr = unsafe {
373 ffi::mps_nn_graph_new(
374 device.as_ptr(),
375 result_image.as_ptr(),
376 result_image_is_needed,
377 )
378 };
379 if ptr.is_null() {
380 None
381 } else {
382 Some(Self { ptr })
383 }
384 }
385
386 #[must_use]
388 pub fn source_image_count(&self) -> usize {
389 unsafe { ffi::mps_nn_graph_source_image_count(self.ptr) }
390 }
391
392 #[must_use]
394 pub fn format(&self) -> usize {
395 unsafe { ffi::mps_nn_graph_format(self.ptr) }
396 }
397
398 pub fn set_format(&self, format: usize) {
400 unsafe { ffi::mps_nn_graph_set_format(self.ptr, format) };
401 }
402
403 pub fn set_output_state_is_temporary(&self, temporary: bool) {
405 unsafe { ffi::mps_nn_graph_set_output_state_is_temporary(self.ptr, temporary) };
406 }
407
408 pub fn use_default_destination_image_allocator(&self) {
410 unsafe { ffi::mps_nn_graph_use_default_destination_image_allocator(self.ptr) };
411 }
412
413 pub fn reload_from_data_sources(&self) {
415 unsafe { ffi::mps_nn_graph_reload_from_data_sources(self.ptr) };
416 }
417
418 #[must_use]
420 pub fn encode(
421 &self,
422 command_buffer: &CommandBuffer,
423 source_images: &[&Image],
424 ) -> Option<Image> {
425 let handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
426 let source_handles = if handles.is_empty() {
427 ptr::null()
428 } else {
429 handles.as_ptr()
430 };
431 let ptr = unsafe {
432 ffi::mps_nn_graph_encode(
433 self.ptr,
434 command_buffer.as_ptr(),
435 source_images.len(),
436 source_handles,
437 )
438 };
439 if ptr.is_null() {
440 None
441 } else {
442 Some(unsafe { Image::from_raw(ptr) })
443 }
444 }
445}
446
447opaque_handle!(CnnConvolutionDescriptor, "Wraps `MPSCNNConvolutionDescriptor`.");
448impl CnnConvolutionDescriptor {
449 #[must_use]
451 pub fn new(
452 kernel_width: usize,
453 kernel_height: usize,
454 input_feature_channels: usize,
455 output_feature_channels: usize,
456 ) -> Option<Self> {
457 let ptr = unsafe {
458 ffi::mps_cnn_convolution_descriptor_new(
459 kernel_width,
460 kernel_height,
461 input_feature_channels,
462 output_feature_channels,
463 )
464 };
465 if ptr.is_null() {
466 None
467 } else {
468 Some(Self { ptr })
469 }
470 }
471
472 #[must_use]
474 pub fn kernel_width(&self) -> usize {
475 unsafe { ffi::mps_cnn_convolution_descriptor_kernel_width(self.ptr) }
476 }
477
478 #[must_use]
480 pub fn kernel_height(&self) -> usize {
481 unsafe { ffi::mps_cnn_convolution_descriptor_kernel_height(self.ptr) }
482 }
483
484 #[must_use]
486 pub fn input_feature_channels(&self) -> usize {
487 unsafe { ffi::mps_cnn_convolution_descriptor_input_feature_channels(self.ptr) }
488 }
489
490 #[must_use]
492 pub fn output_feature_channels(&self) -> usize {
493 unsafe { ffi::mps_cnn_convolution_descriptor_output_feature_channels(self.ptr) }
494 }
495
496 #[must_use]
498 pub fn stride_in_pixels_x(&self) -> usize {
499 unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_x(self.ptr) }
500 }
501
502 pub fn set_stride_in_pixels_x(&self, value: usize) {
504 unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_x(self.ptr, value) };
505 }
506
507 #[must_use]
509 pub fn stride_in_pixels_y(&self) -> usize {
510 unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_y(self.ptr) }
511 }
512
513 pub fn set_stride_in_pixels_y(&self, value: usize) {
515 unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_y(self.ptr, value) };
516 }
517
518 #[must_use]
520 pub fn groups(&self) -> usize {
521 unsafe { ffi::mps_cnn_convolution_descriptor_groups(self.ptr) }
522 }
523
524 pub fn set_groups(&self, value: usize) {
526 unsafe { ffi::mps_cnn_convolution_descriptor_set_groups(self.ptr, value) };
527 }
528
529 #[must_use]
531 pub fn dilation_rate_x(&self) -> usize {
532 unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_x(self.ptr) }
533 }
534
535 pub fn set_dilation_rate_x(&self, value: usize) {
537 unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_x(self.ptr, value) };
538 }
539
540 #[must_use]
542 pub fn dilation_rate_y(&self) -> usize {
543 unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_y(self.ptr) }
544 }
545
546 pub fn set_dilation_rate_y(&self, value: usize) {
548 unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_y(self.ptr, value) };
549 }
550}
551
552opaque_handle!(RnnSingleGateDescriptor, "Wraps `MPSRNNSingleGateDescriptor`.");
553impl RnnSingleGateDescriptor {
554 #[must_use]
556 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
557 let ptr = unsafe {
558 ffi::mps_rnn_single_gate_descriptor_new(input_feature_channels, output_feature_channels)
559 };
560 if ptr.is_null() {
561 None
562 } else {
563 Some(Self { ptr })
564 }
565 }
566
567 #[must_use]
569 pub fn input_feature_channels(&self) -> usize {
570 unsafe { ffi::mps_rnn_single_gate_descriptor_input_feature_channels(self.ptr) }
571 }
572
573 pub fn set_input_feature_channels(&self, value: usize) {
575 unsafe { ffi::mps_rnn_single_gate_descriptor_set_input_feature_channels(self.ptr, value) };
576 }
577
578 #[must_use]
580 pub fn output_feature_channels(&self) -> usize {
581 unsafe { ffi::mps_rnn_single_gate_descriptor_output_feature_channels(self.ptr) }
582 }
583
584 pub fn set_output_feature_channels(&self, value: usize) {
586 unsafe { ffi::mps_rnn_single_gate_descriptor_set_output_feature_channels(self.ptr, value) };
587 }
588
589 #[must_use]
591 pub fn use_layer_input_unit_transform_mode(&self) -> bool {
592 unsafe { ffi::mps_rnn_single_gate_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
593 }
594
595 pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
597 unsafe {
598 ffi::mps_rnn_single_gate_descriptor_set_use_layer_input_unit_transform_mode(
599 self.ptr, value,
600 );
601 };
602 }
603
604 #[must_use]
606 pub fn use_float32_weights(&self) -> bool {
607 unsafe { ffi::mps_rnn_single_gate_descriptor_use_float32_weights(self.ptr) }
608 }
609
610 pub fn set_use_float32_weights(&self, value: bool) {
612 unsafe { ffi::mps_rnn_single_gate_descriptor_set_use_float32_weights(self.ptr, value) };
613 }
614
615 #[must_use]
617 pub fn layer_sequence_direction(&self) -> usize {
618 unsafe { ffi::mps_rnn_single_gate_descriptor_layer_sequence_direction(self.ptr) }
619 }
620
621 pub fn set_layer_sequence_direction(&self, value: usize) {
623 unsafe {
624 ffi::mps_rnn_single_gate_descriptor_set_layer_sequence_direction(self.ptr, value);
625 };
626 }
627
628 #[must_use]
630 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
631 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
632 }
633}
634
635opaque_handle!(CnnConvolution, "Wraps `MPSCNNConvolution`.");
636impl CnnConvolution {
637 #[must_use]
639 pub fn new(
640 device: &MetalDevice,
641 descriptor: &CnnConvolutionDescriptor,
642 kernel_weights: &[f32],
643 bias_terms: Option<&[f32]>,
644 flags: usize,
645 ) -> Option<Self> {
646 if kernel_weights.is_empty() {
647 return None;
648 }
649 let bias_terms_ptr = bias_terms.map_or(ptr::null(), <[f32]>::as_ptr);
650 let ptr = unsafe {
651 ffi::mps_cnn_convolution_new(
652 device.as_ptr(),
653 descriptor.as_ptr(),
654 kernel_weights.as_ptr(),
655 bias_terms_ptr,
656 flags,
657 )
658 };
659 if ptr.is_null() {
660 None
661 } else {
662 Some(Self { ptr })
663 }
664 }
665
666 #[must_use]
668 pub fn input_feature_channels(&self) -> usize {
669 unsafe { ffi::mps_cnn_convolution_input_feature_channels(self.ptr) }
670 }
671
672 #[must_use]
674 pub fn output_feature_channels(&self) -> usize {
675 unsafe { ffi::mps_cnn_convolution_output_feature_channels(self.ptr) }
676 }
677
678 #[must_use]
680 pub fn groups(&self) -> usize {
681 unsafe { ffi::mps_cnn_convolution_groups(self.ptr) }
682 }
683
684 #[must_use]
686 pub fn sub_pixel_scale_factor(&self) -> usize {
687 unsafe { ffi::mps_cnn_convolution_sub_pixel_scale_factor(self.ptr) }
688 }
689
690 #[must_use]
692 pub fn channel_multiplier(&self) -> usize {
693 unsafe { ffi::mps_cnn_convolution_channel_multiplier(self.ptr) }
694 }
695
696 #[must_use]
698 pub fn accumulator_precision_option(&self) -> usize {
699 unsafe { ffi::mps_cnn_convolution_accumulator_precision_option(self.ptr) }
700 }
701
702 pub fn set_accumulator_precision_option(&self, value: usize) {
704 unsafe { ffi::mps_cnn_convolution_set_accumulator_precision_option(self.ptr, value) };
705 }
706
707 pub fn encode_image(
709 &self,
710 command_buffer: &CommandBuffer,
711 source: &Image,
712 destination: &Image,
713 ) {
714 unsafe {
715 ffi::mps_cnn_convolution_encode_image(
716 self.ptr,
717 command_buffer.as_ptr(),
718 source.as_ptr(),
719 destination.as_ptr(),
720 );
721 };
722 }
723}
724
725opaque_handle!(CnnConvolutionWeightsAndBiasesState, "Wraps `MPSCNNConvolutionWeightsAndBiasesState`.");
726impl CnnConvolutionWeightsAndBiasesState {
727 #[must_use]
729 pub fn new_with_buffers(weights: &MetalBuffer, biases: Option<&MetalBuffer>) -> Option<Self> {
730 let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
731 let ptr = unsafe {
732 ffi::mps_cnn_convolution_weights_and_biases_state_new(weights.as_ptr(), biases_ptr)
733 };
734 if ptr.is_null() {
735 None
736 } else {
737 Some(Self { ptr })
738 }
739 }
740
741 #[must_use]
743 pub fn new_with_offsets(
744 weights: &MetalBuffer,
745 weights_offset: usize,
746 biases: Option<&MetalBuffer>,
747 biases_offset: usize,
748 descriptor: &CnnConvolutionDescriptor,
749 ) -> Option<Self> {
750 let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
751 let ptr = unsafe {
752 ffi::mps_cnn_convolution_weights_and_biases_state_new_with_offsets(
753 weights.as_ptr(),
754 weights_offset,
755 biases_ptr,
756 biases_offset,
757 descriptor.as_ptr(),
758 )
759 };
760 if ptr.is_null() {
761 None
762 } else {
763 Some(Self { ptr })
764 }
765 }
766
767 #[must_use]
769 pub fn new_with_device(
770 device: &MetalDevice,
771 descriptor: &CnnConvolutionDescriptor,
772 ) -> Option<Self> {
773 let ptr = unsafe {
774 ffi::mps_cnn_convolution_weights_and_biases_state_new_with_device(
775 device.as_ptr(),
776 descriptor.as_ptr(),
777 )
778 };
779 if ptr.is_null() {
780 None
781 } else {
782 Some(Self { ptr })
783 }
784 }
785
786 #[must_use]
788 pub fn weights_offset(&self) -> usize {
789 unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_weights_offset(self.ptr) }
790 }
791
792 #[must_use]
794 pub fn biases_offset(&self) -> usize {
795 unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_biases_offset(self.ptr) }
796 }
797}
798
799opaque_handle!(NNOptimizerDescriptor, "Wraps `MPSNNOptimizerDescriptor`.");
800impl NNOptimizerDescriptor {
801 #[must_use]
803 pub fn new(
804 learning_rate: f32,
805 gradient_rescale: f32,
806 regularization_type: usize,
807 regularization_scale: f32,
808 ) -> Option<Self> {
809 let ptr = unsafe {
810 ffi::mps_nn_optimizer_descriptor_new(
811 learning_rate,
812 gradient_rescale,
813 regularization_type,
814 regularization_scale,
815 )
816 };
817 if ptr.is_null() {
818 None
819 } else {
820 Some(Self { ptr })
821 }
822 }
823
824 #[must_use]
826 pub fn with_gradient_clipping(
827 learning_rate: f32,
828 gradient_rescale: f32,
829 apply_gradient_clipping: bool,
830 gradient_clip_max: f32,
831 gradient_clip_min: f32,
832 regularization_type: usize,
833 regularization_scale: f32,
834 ) -> Option<Self> {
835 let ptr = unsafe {
836 ffi::mps_nn_optimizer_descriptor_new_with_gradient_clipping(
837 learning_rate,
838 gradient_rescale,
839 apply_gradient_clipping,
840 gradient_clip_max,
841 gradient_clip_min,
842 regularization_type,
843 regularization_scale,
844 )
845 };
846 if ptr.is_null() {
847 None
848 } else {
849 Some(Self { ptr })
850 }
851 }
852
853 #[must_use]
855 pub fn learning_rate(&self) -> f32 {
856 unsafe { ffi::mps_nn_optimizer_descriptor_learning_rate(self.ptr) }
857 }
858
859 pub fn set_learning_rate(&self, value: f32) {
861 unsafe { ffi::mps_nn_optimizer_descriptor_set_learning_rate(self.ptr, value) };
862 }
863
864 #[must_use]
866 pub fn gradient_rescale(&self) -> f32 {
867 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_rescale(self.ptr) }
868 }
869
870 pub fn set_gradient_rescale(&self, value: f32) {
872 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_rescale(self.ptr, value) };
873 }
874
875 #[must_use]
877 pub fn apply_gradient_clipping(&self) -> bool {
878 unsafe { ffi::mps_nn_optimizer_descriptor_apply_gradient_clipping(self.ptr) }
879 }
880
881 pub fn set_apply_gradient_clipping(&self, value: bool) {
883 unsafe { ffi::mps_nn_optimizer_descriptor_set_apply_gradient_clipping(self.ptr, value) };
884 }
885
886 #[must_use]
888 pub fn gradient_clip_max(&self) -> f32 {
889 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_max(self.ptr) }
890 }
891
892 pub fn set_gradient_clip_max(&self, value: f32) {
894 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_max(self.ptr, value) };
895 }
896
897 #[must_use]
899 pub fn gradient_clip_min(&self) -> f32 {
900 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_min(self.ptr) }
901 }
902
903 pub fn set_gradient_clip_min(&self, value: f32) {
905 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_min(self.ptr, value) };
906 }
907
908 #[must_use]
910 pub fn regularization_scale(&self) -> f32 {
911 unsafe { ffi::mps_nn_optimizer_descriptor_regularization_scale(self.ptr) }
912 }
913
914 pub fn set_regularization_scale(&self, value: f32) {
916 unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_scale(self.ptr, value) };
917 }
918
919 #[must_use]
921 pub fn regularization_type(&self) -> usize {
922 unsafe { ffi::mps_nn_optimizer_descriptor_regularization_type(self.ptr) }
923 }
924
925 pub fn set_regularization_type(&self, value: usize) {
927 unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_type(self.ptr, value) };
928 }
929}
930
931opaque_handle!(NNOptimizer, "Wraps `MPSNNOptimizer`.");
932impl_optimizer_common!(NNOptimizer);
933
934opaque_handle!(NNOptimizerStochasticGradientDescent, "Wraps `MPSNNOptimizerStochasticGradientDescent`.");
935impl_optimizer_common!(NNOptimizerStochasticGradientDescent);
936impl NNOptimizerStochasticGradientDescent {
937 #[must_use]
939 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
940 let ptr = unsafe { ffi::mps_nn_optimizer_sgd_new(device.as_ptr(), learning_rate) };
941 if ptr.is_null() {
942 None
943 } else {
944 Some(Self { ptr })
945 }
946 }
947
948 #[must_use]
950 pub fn new_with_options(
951 device: &MetalDevice,
952 momentum_scale: f32,
953 use_nesterov_momentum: bool,
954 optimizer_descriptor: &NNOptimizerDescriptor,
955 ) -> Option<Self> {
956 let ptr = unsafe {
957 ffi::mps_nn_optimizer_sgd_new_with_options(
958 device.as_ptr(),
959 momentum_scale,
960 use_nesterov_momentum,
961 optimizer_descriptor.as_ptr(),
962 )
963 };
964 if ptr.is_null() {
965 None
966 } else {
967 Some(Self { ptr })
968 }
969 }
970
971 #[must_use]
973 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
974 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
975 }
976
977 #[must_use]
979 pub fn momentum_scale(&self) -> f32 {
980 unsafe { ffi::mps_nn_optimizer_sgd_momentum_scale(self.ptr) }
981 }
982
983 #[must_use]
985 pub fn use_nesterov_momentum(&self) -> bool {
986 unsafe { ffi::mps_nn_optimizer_sgd_use_nesterov_momentum(self.ptr) }
987 }
988
989 pub fn encode_vector(
991 &self,
992 command_buffer: &CommandBuffer,
993 input_gradient_vector: &Vector,
994 input_values_vector: &Vector,
995 input_momentum_vector: Option<&Vector>,
996 result_values_vector: &Vector,
997 ) {
998 let input_momentum_ptr = input_momentum_vector.map_or(ptr::null_mut(), Vector::as_ptr);
999 unsafe {
1000 ffi::mps_nn_optimizer_sgd_encode_vector(
1001 self.ptr,
1002 command_buffer.as_ptr(),
1003 input_gradient_vector.as_ptr(),
1004 input_values_vector.as_ptr(),
1005 input_momentum_ptr,
1006 result_values_vector.as_ptr(),
1007 );
1008 };
1009 }
1010
1011 pub fn encode_matrix(
1013 &self,
1014 command_buffer: &CommandBuffer,
1015 input_gradient_matrix: &Matrix,
1016 input_values_matrix: &Matrix,
1017 input_momentum_matrix: Option<&Matrix>,
1018 result_values_matrix: &Matrix,
1019 ) {
1020 let input_momentum_ptr = input_momentum_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
1021 unsafe {
1022 ffi::mps_nn_optimizer_sgd_encode_matrix(
1023 self.ptr,
1024 command_buffer.as_ptr(),
1025 input_gradient_matrix.as_ptr(),
1026 input_values_matrix.as_ptr(),
1027 input_momentum_ptr,
1028 result_values_matrix.as_ptr(),
1029 );
1030 };
1031 }
1032}
1033
1034opaque_handle!(NNOptimizerRmsProp, "Wraps `MPSNNOptimizerRMSProp`.");
1035impl_optimizer_common!(NNOptimizerRmsProp);
1036impl NNOptimizerRmsProp {
1037 #[must_use]
1039 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
1040 let ptr = unsafe { ffi::mps_nn_optimizer_rmsprop_new(device.as_ptr(), learning_rate) };
1041 if ptr.is_null() {
1042 None
1043 } else {
1044 Some(Self { ptr })
1045 }
1046 }
1047
1048 #[must_use]
1050 pub fn new_with_options(
1051 device: &MetalDevice,
1052 decay: f64,
1053 epsilon: f32,
1054 optimizer_descriptor: &NNOptimizerDescriptor,
1055 ) -> Option<Self> {
1056 let ptr = unsafe {
1057 ffi::mps_nn_optimizer_rmsprop_new_with_options(
1058 device.as_ptr(),
1059 decay,
1060 epsilon,
1061 optimizer_descriptor.as_ptr(),
1062 )
1063 };
1064 if ptr.is_null() {
1065 None
1066 } else {
1067 Some(Self { ptr })
1068 }
1069 }
1070
1071 #[must_use]
1073 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
1074 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
1075 }
1076
1077 #[must_use]
1079 pub fn decay(&self) -> f64 {
1080 unsafe { ffi::mps_nn_optimizer_rmsprop_decay(self.ptr) }
1081 }
1082
1083 #[must_use]
1085 pub fn epsilon(&self) -> f32 {
1086 unsafe { ffi::mps_nn_optimizer_rmsprop_epsilon(self.ptr) }
1087 }
1088
1089 pub fn encode_vector(
1091 &self,
1092 command_buffer: &CommandBuffer,
1093 input_gradient_vector: &Vector,
1094 input_values_vector: &Vector,
1095 input_sum_of_squares_vector: &Vector,
1096 result_values_vector: &Vector,
1097 ) {
1098 unsafe {
1099 ffi::mps_nn_optimizer_rmsprop_encode_vector(
1100 self.ptr,
1101 command_buffer.as_ptr(),
1102 input_gradient_vector.as_ptr(),
1103 input_values_vector.as_ptr(),
1104 input_sum_of_squares_vector.as_ptr(),
1105 result_values_vector.as_ptr(),
1106 );
1107 };
1108 }
1109
1110 pub fn encode_matrix(
1112 &self,
1113 command_buffer: &CommandBuffer,
1114 input_gradient_matrix: &Matrix,
1115 input_values_matrix: &Matrix,
1116 input_sum_of_squares_matrix: &Matrix,
1117 result_values_matrix: &Matrix,
1118 ) {
1119 unsafe {
1120 ffi::mps_nn_optimizer_rmsprop_encode_matrix(
1121 self.ptr,
1122 command_buffer.as_ptr(),
1123 input_gradient_matrix.as_ptr(),
1124 input_values_matrix.as_ptr(),
1125 input_sum_of_squares_matrix.as_ptr(),
1126 result_values_matrix.as_ptr(),
1127 );
1128 };
1129 }
1130}
1131
1132opaque_handle!(NNOptimizerAdam, "Wraps `MPSNNOptimizerAdam`.");
1133impl_optimizer_common!(NNOptimizerAdam);
1134impl NNOptimizerAdam {
1135 #[must_use]
1137 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
1138 let ptr = unsafe { ffi::mps_nn_optimizer_adam_new(device.as_ptr(), learning_rate) };
1139 if ptr.is_null() {
1140 None
1141 } else {
1142 Some(Self { ptr })
1143 }
1144 }
1145
1146 #[must_use]
1148 pub fn new_with_options(
1149 device: &MetalDevice,
1150 beta1: f64,
1151 beta2: f64,
1152 epsilon: f32,
1153 time_step: usize,
1154 optimizer_descriptor: &NNOptimizerDescriptor,
1155 ) -> Option<Self> {
1156 let ptr = unsafe {
1157 ffi::mps_nn_optimizer_adam_new_with_options(
1158 device.as_ptr(),
1159 beta1,
1160 beta2,
1161 epsilon,
1162 time_step,
1163 optimizer_descriptor.as_ptr(),
1164 )
1165 };
1166 if ptr.is_null() {
1167 None
1168 } else {
1169 Some(Self { ptr })
1170 }
1171 }
1172
1173 #[must_use]
1175 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
1176 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
1177 }
1178
1179 #[must_use]
1181 pub fn beta1(&self) -> f64 {
1182 unsafe { ffi::mps_nn_optimizer_adam_beta1(self.ptr) }
1183 }
1184
1185 #[must_use]
1187 pub fn beta2(&self) -> f64 {
1188 unsafe { ffi::mps_nn_optimizer_adam_beta2(self.ptr) }
1189 }
1190
1191 #[must_use]
1193 pub fn epsilon(&self) -> f32 {
1194 unsafe { ffi::mps_nn_optimizer_adam_epsilon(self.ptr) }
1195 }
1196
1197 #[must_use]
1199 pub fn time_step(&self) -> usize {
1200 unsafe { ffi::mps_nn_optimizer_adam_time_step(self.ptr) }
1201 }
1202
1203 pub fn set_time_step(&self, value: usize) {
1205 unsafe { ffi::mps_nn_optimizer_adam_set_time_step(self.ptr, value) };
1206 }
1207
1208 pub fn encode_vector(
1210 &self,
1211 command_buffer: &CommandBuffer,
1212 input_gradient_vector: &Vector,
1213 input_values_vector: &Vector,
1214 input_momentum_vector: &Vector,
1215 input_velocity_vector: &Vector,
1216 result_values_vector: &Vector,
1217 ) {
1218 unsafe {
1219 ffi::mps_nn_optimizer_adam_encode_vector(
1220 self.ptr,
1221 command_buffer.as_ptr(),
1222 input_gradient_vector.as_ptr(),
1223 input_values_vector.as_ptr(),
1224 input_momentum_vector.as_ptr(),
1225 input_velocity_vector.as_ptr(),
1226 result_values_vector.as_ptr(),
1227 );
1228 };
1229 }
1230
1231 pub fn encode_matrix(
1233 &self,
1234 command_buffer: &CommandBuffer,
1235 input_gradient_matrix: &Matrix,
1236 input_values_matrix: &Matrix,
1237 input_momentum_matrix: &Matrix,
1238 input_velocity_matrix: &Matrix,
1239 result_values_matrix: &Matrix,
1240 ) {
1241 unsafe {
1242 ffi::mps_nn_optimizer_adam_encode_matrix(
1243 self.ptr,
1244 command_buffer.as_ptr(),
1245 input_gradient_matrix.as_ptr(),
1246 input_values_matrix.as_ptr(),
1247 input_momentum_matrix.as_ptr(),
1248 input_velocity_matrix.as_ptr(),
1249 result_values_matrix.as_ptr(),
1250 );
1251 };
1252 }
1253
1254 #[allow(clippy::too_many_arguments)]
1256 pub fn encode_amsgrad_vector(
1257 &self,
1258 command_buffer: &CommandBuffer,
1259 input_gradient_vector: &Vector,
1260 input_values_vector: &Vector,
1261 input_momentum_vector: &Vector,
1262 input_velocity_vector: &Vector,
1263 maximum_velocity_vector: Option<&Vector>,
1264 result_values_vector: &Vector,
1265 ) {
1266 let maximum_velocity_ptr = maximum_velocity_vector.map_or(ptr::null_mut(), Vector::as_ptr);
1267 unsafe {
1268 ffi::mps_nn_optimizer_adam_encode_amsgrad_vector(
1269 self.ptr,
1270 command_buffer.as_ptr(),
1271 input_gradient_vector.as_ptr(),
1272 input_values_vector.as_ptr(),
1273 input_momentum_vector.as_ptr(),
1274 input_velocity_vector.as_ptr(),
1275 maximum_velocity_ptr,
1276 result_values_vector.as_ptr(),
1277 );
1278 };
1279 }
1280
1281 #[allow(clippy::too_many_arguments)]
1283 pub fn encode_amsgrad_matrix(
1284 &self,
1285 command_buffer: &CommandBuffer,
1286 input_gradient_matrix: &Matrix,
1287 input_values_matrix: &Matrix,
1288 input_momentum_matrix: &Matrix,
1289 input_velocity_matrix: &Matrix,
1290 maximum_velocity_matrix: Option<&Matrix>,
1291 result_values_matrix: &Matrix,
1292 ) {
1293 let maximum_velocity_ptr = maximum_velocity_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
1294 unsafe {
1295 ffi::mps_nn_optimizer_adam_encode_amsgrad_matrix(
1296 self.ptr,
1297 command_buffer.as_ptr(),
1298 input_gradient_matrix.as_ptr(),
1299 input_values_matrix.as_ptr(),
1300 input_momentum_matrix.as_ptr(),
1301 input_velocity_matrix.as_ptr(),
1302 maximum_velocity_ptr,
1303 result_values_matrix.as_ptr(),
1304 );
1305 };
1306 }
1307}
1308
1309opaque_handle!(RnnDescriptor, "Wraps `MPSRNNDescriptor`.");
1310impl_rnn_descriptor_common!(RnnDescriptor);
1311
1312opaque_handle!(GruDescriptor, "Wraps `MPSGRUDescriptor`.");
1313impl_rnn_descriptor_common!(GruDescriptor);
1314impl GruDescriptor {
1315 #[must_use]
1317 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1318 let ptr =
1319 unsafe { ffi::mps_gru_descriptor_new(input_feature_channels, output_feature_channels) };
1320 if ptr.is_null() {
1321 None
1322 } else {
1323 Some(Self { ptr })
1324 }
1325 }
1326
1327 #[must_use]
1329 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1330 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1331 }
1332
1333 #[must_use]
1335 pub fn gate_pnorm_value(&self) -> f32 {
1336 unsafe { ffi::mps_gru_descriptor_gate_pnorm_value(self.ptr) }
1337 }
1338
1339 pub fn set_gate_pnorm_value(&self, value: f32) {
1341 unsafe { ffi::mps_gru_descriptor_set_gate_pnorm_value(self.ptr, value) };
1342 }
1343
1344 #[must_use]
1346 pub fn flip_output_gates(&self) -> bool {
1347 unsafe { ffi::mps_gru_descriptor_flip_output_gates(self.ptr) }
1348 }
1349
1350 pub fn set_flip_output_gates(&self, value: bool) {
1352 unsafe { ffi::mps_gru_descriptor_set_flip_output_gates(self.ptr, value) };
1353 }
1354}
1355
1356opaque_handle!(LstmDescriptor, "Wraps `MPSLSTMDescriptor`.");
1357impl_rnn_descriptor_common!(LstmDescriptor);
1358impl LstmDescriptor {
1359 #[must_use]
1361 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1362 let ptr = unsafe {
1363 ffi::mps_lstm_descriptor_new(input_feature_channels, output_feature_channels)
1364 };
1365 if ptr.is_null() {
1366 None
1367 } else {
1368 Some(Self { ptr })
1369 }
1370 }
1371
1372 #[must_use]
1374 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1375 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1376 }
1377
1378 #[must_use]
1380 pub fn memory_weights_are_diagonal(&self) -> bool {
1381 unsafe { ffi::mps_lstm_descriptor_memory_weights_are_diagonal(self.ptr) }
1382 }
1383
1384 pub fn set_memory_weights_are_diagonal(&self, value: bool) {
1386 unsafe { ffi::mps_lstm_descriptor_set_memory_weights_are_diagonal(self.ptr, value) };
1387 }
1388
1389 #[must_use]
1391 pub fn cell_to_output_neuron_type(&self) -> usize {
1392 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_type(self.ptr) }
1393 }
1394
1395 pub fn set_cell_to_output_neuron_type(&self, value: usize) {
1397 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_type(self.ptr, value) };
1398 }
1399
1400 #[must_use]
1402 pub fn cell_to_output_neuron_param_a(&self) -> f32 {
1403 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_a(self.ptr) }
1404 }
1405
1406 pub fn set_cell_to_output_neuron_param_a(&self, value: f32) {
1408 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_a(self.ptr, value) };
1409 }
1410
1411 #[must_use]
1413 pub fn cell_to_output_neuron_param_b(&self) -> f32 {
1414 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_b(self.ptr) }
1415 }
1416
1417 pub fn set_cell_to_output_neuron_param_b(&self, value: f32) {
1419 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_b(self.ptr, value) };
1420 }
1421
1422 #[must_use]
1424 pub fn cell_to_output_neuron_param_c(&self) -> f32 {
1425 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_c(self.ptr) }
1426 }
1427
1428 pub fn set_cell_to_output_neuron_param_c(&self, value: f32) {
1430 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_c(self.ptr, value) };
1431 }
1432}
1433
1434opaque_handle!(RnnRecurrentImageState, "Wraps `MPSRNNRecurrentImageState`.");
1435impl RnnRecurrentImageState {
1436 #[must_use]
1438 pub fn recurrent_output_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1439 let ptr = unsafe {
1440 ffi::mps_rnn_recurrent_image_state_recurrent_output_image(self.ptr, layer_index)
1441 };
1442 if ptr.is_null() {
1443 None
1444 } else {
1445 Some(unsafe { Image::from_raw(ptr) })
1446 }
1447 }
1448
1449 #[must_use]
1451 pub fn memory_cell_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1452 let ptr =
1453 unsafe { ffi::mps_rnn_recurrent_image_state_memory_cell_image(self.ptr, layer_index) };
1454 if ptr.is_null() {
1455 None
1456 } else {
1457 Some(unsafe { Image::from_raw(ptr) })
1458 }
1459 }
1460}
1461
1462opaque_handle!(RnnImageInferenceLayer, "Wraps `MPSRNNImageInferenceLayer`.");
1463impl RnnImageInferenceLayer {
1464 #[must_use]
1466 pub fn new(device: &MetalDevice, descriptor: &RnnDescriptor) -> Option<Self> {
1467 let ptr =
1468 unsafe { ffi::mps_rnn_image_inference_layer_new(device.as_ptr(), descriptor.as_ptr()) };
1469 if ptr.is_null() {
1470 None
1471 } else {
1472 Some(Self { ptr })
1473 }
1474 }
1475
1476 #[must_use]
1478 pub fn new_stack(device: &MetalDevice, descriptors: &[&RnnDescriptor]) -> Option<Self> {
1479 let handles: Vec<_> = descriptors
1480 .iter()
1481 .map(|descriptor| descriptor.as_ptr())
1482 .collect();
1483 let handles_ptr = if handles.is_empty() {
1484 ptr::null()
1485 } else {
1486 handles.as_ptr()
1487 };
1488 let ptr = unsafe {
1489 ffi::mps_rnn_image_inference_layer_new_stack(
1490 device.as_ptr(),
1491 descriptors.len(),
1492 handles_ptr,
1493 )
1494 };
1495 if ptr.is_null() {
1496 None
1497 } else {
1498 Some(Self { ptr })
1499 }
1500 }
1501
1502 #[must_use]
1504 pub fn input_feature_channels(&self) -> usize {
1505 unsafe { ffi::mps_rnn_image_inference_layer_input_feature_channels(self.ptr) }
1506 }
1507
1508 #[must_use]
1510 pub fn output_feature_channels(&self) -> usize {
1511 unsafe { ffi::mps_rnn_image_inference_layer_output_feature_channels(self.ptr) }
1512 }
1513
1514 #[must_use]
1516 pub fn number_of_layers(&self) -> usize {
1517 unsafe { ffi::mps_rnn_image_inference_layer_number_of_layers(self.ptr) }
1518 }
1519
1520 #[must_use]
1522 pub fn recurrent_output_is_temporary(&self) -> bool {
1523 unsafe { ffi::mps_rnn_image_inference_layer_recurrent_output_is_temporary(self.ptr) }
1524 }
1525
1526 pub fn set_recurrent_output_is_temporary(&self, value: bool) {
1528 unsafe {
1529 ffi::mps_rnn_image_inference_layer_set_recurrent_output_is_temporary(self.ptr, value);
1530 }
1531 }
1532
1533 #[must_use]
1535 pub fn store_all_intermediate_states(&self) -> bool {
1536 unsafe { ffi::mps_rnn_image_inference_layer_store_all_intermediate_states(self.ptr) }
1537 }
1538
1539 pub fn set_store_all_intermediate_states(&self, value: bool) {
1541 unsafe {
1542 ffi::mps_rnn_image_inference_layer_set_store_all_intermediate_states(self.ptr, value);
1543 }
1544 }
1545
1546 #[must_use]
1548 pub fn bidirectional_combine_mode(&self) -> usize {
1549 unsafe { ffi::mps_rnn_image_inference_layer_bidirectional_combine_mode(self.ptr) }
1550 }
1551
1552 pub fn set_bidirectional_combine_mode(&self, value: usize) {
1554 unsafe {
1555 ffi::mps_rnn_image_inference_layer_set_bidirectional_combine_mode(self.ptr, value);
1556 }
1557 }
1558
1559 #[must_use]
1561 pub fn encode_sequence(
1562 &self,
1563 command_buffer: &CommandBuffer,
1564 source_images: &[&Image],
1565 destination_images: &[&Image],
1566 recurrent_input_state: Option<&RnnRecurrentImageState>,
1567 ) -> Option<RnnRecurrentImageState> {
1568 if source_images.len() != destination_images.len() {
1569 return None;
1570 }
1571 let source_handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
1572 let destination_handles: Vec<_> = destination_images
1573 .iter()
1574 .map(|image| image.as_ptr())
1575 .collect();
1576 let source_ptr = if source_handles.is_empty() {
1577 ptr::null()
1578 } else {
1579 source_handles.as_ptr()
1580 };
1581 let destination_ptr = if destination_handles.is_empty() {
1582 ptr::null()
1583 } else {
1584 destination_handles.as_ptr()
1585 };
1586 let recurrent_input_ptr =
1587 recurrent_input_state.map_or(ptr::null_mut(), RnnRecurrentImageState::as_ptr);
1588 let ptr = unsafe {
1589 ffi::mps_rnn_image_inference_layer_encode_sequence(
1590 self.ptr,
1591 command_buffer.as_ptr(),
1592 source_images.len(),
1593 source_ptr,
1594 destination_ptr,
1595 recurrent_input_ptr,
1596 )
1597 };
1598 if ptr.is_null() {
1599 None
1600 } else {
1601 Some(RnnRecurrentImageState { ptr })
1602 }
1603 }
1604}