1use crate::ffi;
2use crate::image::Image;
3use crate::matrix::{Matrix, Vector};
4use apple_metal::{CommandBuffer, MetalBuffer, MetalDevice};
5use core::ffi::c_void;
6use core::ptr;
7
/// Direction in which an RNN layer walks its input sequence.
pub mod rnn_sequence_direction {
    /// Process the sequence from the first element to the last.
    pub const FORWARD: usize = 0;
    /// Process the sequence from the last element to the first.
    pub const BACKWARD: usize = 1;
}
13
/// How the forward and backward passes of a bidirectional RNN are merged.
pub mod rnn_bidirectional_combine_mode {
    /// Keep the two directions as separate outputs.
    pub const NONE: usize = 0;
    /// Element-wise sum of the two directional outputs.
    pub const ADD: usize = 1;
    /// Concatenate the two directional outputs along the feature axis.
    pub const CONCATENATE: usize = 2;
}
20
/// Option flags accepted by the CNN convolution constructors.
pub mod cnn_convolution_flags {
    /// No special options.
    pub const NONE: usize = 0;
}
25
/// Memory layout of convolution weight tensors.
pub mod cnn_convolution_weights_layout {
    /// Output-channels / height / width / input-channels ordering.
    // NOTE(review): this constant is `u32` while every sibling constant
    // module in this file uses `usize`; confirm against the FFI signature
    // before unifying the types (changing it would break callers).
    pub const OHWI: u32 = 0;
}
30
/// Precision used for the convolution accumulator.
pub mod cnn_accumulator_precision_option {
    /// Accumulate in half precision.
    pub const HALF: usize = 0;
    /// Accumulate in single precision.
    pub const FLOAT: usize = 1;
}
36
/// Weight-regularization scheme applied by the optimizers.
pub mod nn_regularization_type {
    /// No regularization.
    pub const NONE: usize = 0;
    /// L1 (lasso) regularization.
    pub const L1: usize = 1;
    /// L2 (ridge) regularization.
    pub const L2: usize = 2;
}
43
/// Declares an owning wrapper type around a raw MPS object pointer.
///
/// The generated type releases the pointer through
/// `ffi::mps_object_release` on drop and exposes the raw pointer via
/// `as_ptr` without transferring ownership.
macro_rules! opaque_handle {
    ($name:ident) => {
        pub struct $name {
            // Raw MPS object handle; live handles are never null because
            // every constructor in this file returns `None` on a null ptr.
            ptr: *mut c_void,
        }

        // SAFETY: the handle is an owned, retained reference to the MPS
        // object. NOTE(review): this assumes the underlying MPS objects
        // tolerate cross-thread use — confirm against the FFI shim.
        unsafe impl Send for $name {}
        unsafe impl Sync for $name {}

        impl Drop for $name {
            fn drop(&mut self) {
                if !self.ptr.is_null() {
                    unsafe { ffi::mps_object_release(self.ptr) };
                    // Null out so a (hypothetical) second drop cannot
                    // release the same object twice.
                    self.ptr = ptr::null_mut();
                }
            }
        }

        impl $name {
            /// Returns the raw pointer without transferring ownership.
            #[must_use]
            pub const fn as_ptr(&self) -> *mut c_void {
                self.ptr
            }
        }
    };
}
70
/// Adds a `result_image` accessor to a filter-node wrapper, returning the
/// node's output image as an `NNImageNode` (or `None` on a null handle).
macro_rules! impl_filter_result_image {
    ($name:ident) => {
        impl $name {
            /// Output image node produced by this filter node.
            #[must_use]
            pub fn result_image(&self) -> Option<NNImageNode> {
                let ptr = unsafe { ffi::mps_nn_filter_node_result_image(self.ptr) };
                if ptr.is_null() {
                    None
                } else {
                    Some(NNImageNode { ptr })
                }
            }
        }
    };
}
86
87fn retained_handle(ptr: *mut c_void) -> Option<*mut c_void> {
88 let retained = unsafe { ffi::mps_object_retain(ptr) };
89 if retained.is_null() {
90 None
91 } else {
92 Some(retained)
93 }
94}
95
/// Implements the accessor set shared by every RNN descriptor wrapper:
/// channel counts, transform mode, weight precision and sequence direction.
/// All generated methods forward to the generic `mps_rnn_descriptor_*` FFI
/// entry points, regardless of the concrete wrapper type.
macro_rules! impl_rnn_descriptor_common {
    ($name:ident) => {
        impl $name {
            /// Number of feature channels in the layer input.
            #[must_use]
            pub fn input_feature_channels(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_input_feature_channels(self.ptr) }
            }

            /// Sets the number of input feature channels.
            pub fn set_input_feature_channels(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_input_feature_channels(self.ptr, value) };
            }

            /// Number of feature channels in the layer output.
            #[must_use]
            pub fn output_feature_channels(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_output_feature_channels(self.ptr) }
            }

            /// Sets the number of output feature channels.
            pub fn set_output_feature_channels(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_output_feature_channels(self.ptr, value) };
            }

            /// Whether the layer-input unit-transform mode is enabled.
            #[must_use]
            pub fn use_layer_input_unit_transform_mode(&self) -> bool {
                unsafe { ffi::mps_rnn_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
            }

            /// Enables or disables the layer-input unit-transform mode.
            pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
                unsafe {
                    ffi::mps_rnn_descriptor_set_use_layer_input_unit_transform_mode(self.ptr, value)
                };
            }

            /// Whether weights are stored as 32-bit floats.
            #[must_use]
            pub fn use_float32_weights(&self) -> bool {
                unsafe { ffi::mps_rnn_descriptor_use_float32_weights(self.ptr) }
            }

            /// Sets whether weights are stored as 32-bit floats.
            pub fn set_use_float32_weights(&self, value: bool) {
                unsafe { ffi::mps_rnn_descriptor_set_use_float32_weights(self.ptr, value) };
            }

            /// Sequence direction (see `rnn_sequence_direction`).
            #[must_use]
            pub fn layer_sequence_direction(&self) -> usize {
                unsafe { ffi::mps_rnn_descriptor_layer_sequence_direction(self.ptr) }
            }

            /// Sets the sequence direction (see `rnn_sequence_direction`).
            pub fn set_layer_sequence_direction(&self, value: usize) {
                unsafe { ffi::mps_rnn_descriptor_set_layer_sequence_direction(self.ptr, value) };
            }
        }
    };
}
148
/// Implements the accessor set shared by every optimizer wrapper:
/// learning rate, gradient rescale/clipping and regularization settings.
/// All generated methods forward to the generic `mps_nn_optimizer_*` FFI
/// entry points, regardless of the concrete wrapper type.
macro_rules! impl_optimizer_common {
    ($name:ident) => {
        impl $name {
            /// Current learning rate.
            #[must_use]
            pub fn learning_rate(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_learning_rate(self.ptr) }
            }

            /// Sets the learning rate.
            pub fn set_learning_rate(&self, value: f32) {
                unsafe { ffi::mps_nn_optimizer_set_learning_rate(self.ptr, value) };
            }

            /// Scale factor applied to incoming gradients.
            #[must_use]
            pub fn gradient_rescale(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_rescale(self.ptr) }
            }

            /// Whether gradient clipping is applied.
            #[must_use]
            pub fn apply_gradient_clipping(&self) -> bool {
                unsafe { ffi::mps_nn_optimizer_apply_gradient_clipping(self.ptr) }
            }

            /// Enables or disables gradient clipping.
            pub fn set_apply_gradient_clipping(&self, value: bool) {
                unsafe { ffi::mps_nn_optimizer_set_apply_gradient_clipping(self.ptr, value) };
            }

            /// Upper clipping bound for gradients.
            #[must_use]
            pub fn gradient_clip_max(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_clip_max(self.ptr) }
            }

            /// Lower clipping bound for gradients.
            #[must_use]
            pub fn gradient_clip_min(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_gradient_clip_min(self.ptr) }
            }

            /// Scale of the regularization term.
            #[must_use]
            pub fn regularization_scale(&self) -> f32 {
                unsafe { ffi::mps_nn_optimizer_regularization_scale(self.ptr) }
            }

            /// Regularization scheme (see `nn_regularization_type`).
            #[must_use]
            pub fn regularization_type(&self) -> usize {
                unsafe { ffi::mps_nn_optimizer_regularization_type(self.ptr) }
            }
        }
    };
}
197
198opaque_handle!(NNImageNode);
199impl NNImageNode {
200 #[must_use]
201 pub fn new() -> Option<Self> {
202 let ptr = unsafe { ffi::mps_nn_image_node_new() };
203 if ptr.is_null() {
204 None
205 } else {
206 Some(Self { ptr })
207 }
208 }
209
210 #[must_use]
211 pub fn exported() -> Option<Self> {
212 let ptr = unsafe { ffi::mps_nn_image_node_exported() };
213 if ptr.is_null() {
214 None
215 } else {
216 Some(Self { ptr })
217 }
218 }
219
220 #[must_use]
221 pub fn format(&self) -> usize {
222 unsafe { ffi::mps_nn_image_node_format(self.ptr) }
223 }
224
225 pub fn set_format(&self, format: usize) {
226 unsafe { ffi::mps_nn_image_node_set_format(self.ptr, format) };
227 }
228
229 #[must_use]
230 pub fn export_from_graph(&self) -> bool {
231 unsafe { ffi::mps_nn_image_node_export_from_graph(self.ptr) }
232 }
233
234 pub fn set_export_from_graph(&self, export: bool) {
235 unsafe { ffi::mps_nn_image_node_set_export_from_graph(self.ptr, export) };
236 }
237
238 #[must_use]
239 pub fn synchronize_resource(&self) -> bool {
240 unsafe { ffi::mps_nn_image_node_synchronize_resource(self.ptr) }
241 }
242
243 pub fn set_synchronize_resource(&self, synchronize: bool) {
244 unsafe { ffi::mps_nn_image_node_set_synchronize_resource(self.ptr, synchronize) };
245 }
246
247 pub fn use_default_allocator(&self) {
248 unsafe { ffi::mps_nn_image_node_use_default_allocator(self.ptr) };
249 }
250}
251
252opaque_handle!(CnnNeuronReluNode);
253impl CnnNeuronReluNode {
254 #[must_use]
255 pub fn new(source: &NNImageNode, a: f32) -> Option<Self> {
256 let ptr = unsafe { ffi::mps_cnn_neuron_relu_node_new(source.as_ptr(), a) };
257 if ptr.is_null() {
258 None
259 } else {
260 Some(Self { ptr })
261 }
262 }
263}
264impl_filter_result_image!(CnnNeuronReluNode);
265
266opaque_handle!(CnnPoolingMaxNode);
267impl CnnPoolingMaxNode {
268 #[must_use]
269 pub fn new(source: &NNImageNode, filter_size: usize, stride: usize) -> Option<Self> {
270 let ptr =
271 unsafe { ffi::mps_cnn_pooling_max_node_new(source.as_ptr(), filter_size, stride) };
272 if ptr.is_null() {
273 None
274 } else {
275 Some(Self { ptr })
276 }
277 }
278}
279impl_filter_result_image!(CnnPoolingMaxNode);
280
281opaque_handle!(CnnSoftMaxNode);
282impl CnnSoftMaxNode {
283 #[must_use]
284 pub fn new(source: &NNImageNode) -> Option<Self> {
285 let ptr = unsafe { ffi::mps_cnn_softmax_node_new(source.as_ptr()) };
286 if ptr.is_null() {
287 None
288 } else {
289 Some(Self { ptr })
290 }
291 }
292}
293impl_filter_result_image!(CnnSoftMaxNode);
294
295opaque_handle!(CnnUpsamplingNearestNode);
296impl CnnUpsamplingNearestNode {
297 #[must_use]
298 pub fn new(source: &NNImageNode, scale_x: usize, scale_y: usize) -> Option<Self> {
299 let ptr =
300 unsafe { ffi::mps_cnn_upsampling_nearest_node_new(source.as_ptr(), scale_x, scale_y) };
301 if ptr.is_null() {
302 None
303 } else {
304 Some(Self { ptr })
305 }
306 }
307}
308impl_filter_result_image!(CnnUpsamplingNearestNode);
309
310opaque_handle!(NNGraph);
311impl NNGraph {
312 #[must_use]
313 pub fn new(
314 device: &MetalDevice,
315 result_image: &NNImageNode,
316 result_image_is_needed: bool,
317 ) -> Option<Self> {
318 let ptr = unsafe {
319 ffi::mps_nn_graph_new(
320 device.as_ptr(),
321 result_image.as_ptr(),
322 result_image_is_needed,
323 )
324 };
325 if ptr.is_null() {
326 None
327 } else {
328 Some(Self { ptr })
329 }
330 }
331
332 #[must_use]
333 pub fn source_image_count(&self) -> usize {
334 unsafe { ffi::mps_nn_graph_source_image_count(self.ptr) }
335 }
336
337 #[must_use]
338 pub fn format(&self) -> usize {
339 unsafe { ffi::mps_nn_graph_format(self.ptr) }
340 }
341
342 pub fn set_format(&self, format: usize) {
343 unsafe { ffi::mps_nn_graph_set_format(self.ptr, format) };
344 }
345
346 pub fn set_output_state_is_temporary(&self, temporary: bool) {
347 unsafe { ffi::mps_nn_graph_set_output_state_is_temporary(self.ptr, temporary) };
348 }
349
350 pub fn use_default_destination_image_allocator(&self) {
351 unsafe { ffi::mps_nn_graph_use_default_destination_image_allocator(self.ptr) };
352 }
353
354 pub fn reload_from_data_sources(&self) {
355 unsafe { ffi::mps_nn_graph_reload_from_data_sources(self.ptr) };
356 }
357
358 #[must_use]
359 pub fn encode(
360 &self,
361 command_buffer: &CommandBuffer,
362 source_images: &[&Image],
363 ) -> Option<Image> {
364 let handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
365 let source_handles = if handles.is_empty() {
366 ptr::null()
367 } else {
368 handles.as_ptr()
369 };
370 let ptr = unsafe {
371 ffi::mps_nn_graph_encode(
372 self.ptr,
373 command_buffer.as_ptr(),
374 source_images.len(),
375 source_handles,
376 )
377 };
378 if ptr.is_null() {
379 None
380 } else {
381 Some(unsafe { Image::from_raw(ptr) })
382 }
383 }
384}
385
386opaque_handle!(CnnConvolutionDescriptor);
387impl CnnConvolutionDescriptor {
388 #[must_use]
389 pub fn new(
390 kernel_width: usize,
391 kernel_height: usize,
392 input_feature_channels: usize,
393 output_feature_channels: usize,
394 ) -> Option<Self> {
395 let ptr = unsafe {
396 ffi::mps_cnn_convolution_descriptor_new(
397 kernel_width,
398 kernel_height,
399 input_feature_channels,
400 output_feature_channels,
401 )
402 };
403 if ptr.is_null() {
404 None
405 } else {
406 Some(Self { ptr })
407 }
408 }
409
410 #[must_use]
411 pub fn kernel_width(&self) -> usize {
412 unsafe { ffi::mps_cnn_convolution_descriptor_kernel_width(self.ptr) }
413 }
414
415 #[must_use]
416 pub fn kernel_height(&self) -> usize {
417 unsafe { ffi::mps_cnn_convolution_descriptor_kernel_height(self.ptr) }
418 }
419
420 #[must_use]
421 pub fn input_feature_channels(&self) -> usize {
422 unsafe { ffi::mps_cnn_convolution_descriptor_input_feature_channels(self.ptr) }
423 }
424
425 #[must_use]
426 pub fn output_feature_channels(&self) -> usize {
427 unsafe { ffi::mps_cnn_convolution_descriptor_output_feature_channels(self.ptr) }
428 }
429
430 #[must_use]
431 pub fn stride_in_pixels_x(&self) -> usize {
432 unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_x(self.ptr) }
433 }
434
435 pub fn set_stride_in_pixels_x(&self, value: usize) {
436 unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_x(self.ptr, value) };
437 }
438
439 #[must_use]
440 pub fn stride_in_pixels_y(&self) -> usize {
441 unsafe { ffi::mps_cnn_convolution_descriptor_stride_in_pixels_y(self.ptr) }
442 }
443
444 pub fn set_stride_in_pixels_y(&self, value: usize) {
445 unsafe { ffi::mps_cnn_convolution_descriptor_set_stride_in_pixels_y(self.ptr, value) };
446 }
447
448 #[must_use]
449 pub fn groups(&self) -> usize {
450 unsafe { ffi::mps_cnn_convolution_descriptor_groups(self.ptr) }
451 }
452
453 pub fn set_groups(&self, value: usize) {
454 unsafe { ffi::mps_cnn_convolution_descriptor_set_groups(self.ptr, value) };
455 }
456
457 #[must_use]
458 pub fn dilation_rate_x(&self) -> usize {
459 unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_x(self.ptr) }
460 }
461
462 pub fn set_dilation_rate_x(&self, value: usize) {
463 unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_x(self.ptr, value) };
464 }
465
466 #[must_use]
467 pub fn dilation_rate_y(&self) -> usize {
468 unsafe { ffi::mps_cnn_convolution_descriptor_dilation_rate_y(self.ptr) }
469 }
470
471 pub fn set_dilation_rate_y(&self, value: usize) {
472 unsafe { ffi::mps_cnn_convolution_descriptor_set_dilation_rate_y(self.ptr, value) };
473 }
474}
475
476opaque_handle!(RnnSingleGateDescriptor);
477impl RnnSingleGateDescriptor {
478 #[must_use]
479 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
480 let ptr = unsafe {
481 ffi::mps_rnn_single_gate_descriptor_new(input_feature_channels, output_feature_channels)
482 };
483 if ptr.is_null() {
484 None
485 } else {
486 Some(Self { ptr })
487 }
488 }
489
490 #[must_use]
491 pub fn input_feature_channels(&self) -> usize {
492 unsafe { ffi::mps_rnn_single_gate_descriptor_input_feature_channels(self.ptr) }
493 }
494
495 pub fn set_input_feature_channels(&self, value: usize) {
496 unsafe { ffi::mps_rnn_single_gate_descriptor_set_input_feature_channels(self.ptr, value) };
497 }
498
499 #[must_use]
500 pub fn output_feature_channels(&self) -> usize {
501 unsafe { ffi::mps_rnn_single_gate_descriptor_output_feature_channels(self.ptr) }
502 }
503
504 pub fn set_output_feature_channels(&self, value: usize) {
505 unsafe { ffi::mps_rnn_single_gate_descriptor_set_output_feature_channels(self.ptr, value) };
506 }
507
508 #[must_use]
509 pub fn use_layer_input_unit_transform_mode(&self) -> bool {
510 unsafe { ffi::mps_rnn_single_gate_descriptor_use_layer_input_unit_transform_mode(self.ptr) }
511 }
512
513 pub fn set_use_layer_input_unit_transform_mode(&self, value: bool) {
514 unsafe {
515 ffi::mps_rnn_single_gate_descriptor_set_use_layer_input_unit_transform_mode(
516 self.ptr, value,
517 );
518 };
519 }
520
521 #[must_use]
522 pub fn use_float32_weights(&self) -> bool {
523 unsafe { ffi::mps_rnn_single_gate_descriptor_use_float32_weights(self.ptr) }
524 }
525
526 pub fn set_use_float32_weights(&self, value: bool) {
527 unsafe { ffi::mps_rnn_single_gate_descriptor_set_use_float32_weights(self.ptr, value) };
528 }
529
530 #[must_use]
531 pub fn layer_sequence_direction(&self) -> usize {
532 unsafe { ffi::mps_rnn_single_gate_descriptor_layer_sequence_direction(self.ptr) }
533 }
534
535 pub fn set_layer_sequence_direction(&self, value: usize) {
536 unsafe {
537 ffi::mps_rnn_single_gate_descriptor_set_layer_sequence_direction(self.ptr, value);
538 };
539 }
540
541 #[must_use]
542 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
543 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
544 }
545}
546
547opaque_handle!(CnnConvolution);
548impl CnnConvolution {
549 #[must_use]
550 pub fn new(
551 device: &MetalDevice,
552 descriptor: &CnnConvolutionDescriptor,
553 kernel_weights: &[f32],
554 bias_terms: Option<&[f32]>,
555 flags: usize,
556 ) -> Option<Self> {
557 if kernel_weights.is_empty() {
558 return None;
559 }
560 let bias_terms_ptr = bias_terms.map_or(ptr::null(), <[f32]>::as_ptr);
561 let ptr = unsafe {
562 ffi::mps_cnn_convolution_new(
563 device.as_ptr(),
564 descriptor.as_ptr(),
565 kernel_weights.as_ptr(),
566 bias_terms_ptr,
567 flags,
568 )
569 };
570 if ptr.is_null() {
571 None
572 } else {
573 Some(Self { ptr })
574 }
575 }
576
577 #[must_use]
578 pub fn input_feature_channels(&self) -> usize {
579 unsafe { ffi::mps_cnn_convolution_input_feature_channels(self.ptr) }
580 }
581
582 #[must_use]
583 pub fn output_feature_channels(&self) -> usize {
584 unsafe { ffi::mps_cnn_convolution_output_feature_channels(self.ptr) }
585 }
586
587 #[must_use]
588 pub fn groups(&self) -> usize {
589 unsafe { ffi::mps_cnn_convolution_groups(self.ptr) }
590 }
591
592 #[must_use]
593 pub fn sub_pixel_scale_factor(&self) -> usize {
594 unsafe { ffi::mps_cnn_convolution_sub_pixel_scale_factor(self.ptr) }
595 }
596
597 #[must_use]
598 pub fn channel_multiplier(&self) -> usize {
599 unsafe { ffi::mps_cnn_convolution_channel_multiplier(self.ptr) }
600 }
601
602 #[must_use]
603 pub fn accumulator_precision_option(&self) -> usize {
604 unsafe { ffi::mps_cnn_convolution_accumulator_precision_option(self.ptr) }
605 }
606
607 pub fn set_accumulator_precision_option(&self, value: usize) {
608 unsafe { ffi::mps_cnn_convolution_set_accumulator_precision_option(self.ptr, value) };
609 }
610
611 pub fn encode_image(&self, command_buffer: &CommandBuffer, source: &Image, destination: &Image) {
612 unsafe {
613 ffi::mps_cnn_convolution_encode_image(
614 self.ptr,
615 command_buffer.as_ptr(),
616 source.as_ptr(),
617 destination.as_ptr(),
618 );
619 };
620 }
621}
622
623opaque_handle!(CnnConvolutionWeightsAndBiasesState);
624impl CnnConvolutionWeightsAndBiasesState {
625 #[must_use]
626 pub fn new_with_buffers(weights: &MetalBuffer, biases: Option<&MetalBuffer>) -> Option<Self> {
627 let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
628 let ptr = unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_new(weights.as_ptr(), biases_ptr) };
629 if ptr.is_null() {
630 None
631 } else {
632 Some(Self { ptr })
633 }
634 }
635
636 #[must_use]
637 pub fn new_with_offsets(
638 weights: &MetalBuffer,
639 weights_offset: usize,
640 biases: Option<&MetalBuffer>,
641 biases_offset: usize,
642 descriptor: &CnnConvolutionDescriptor,
643 ) -> Option<Self> {
644 let biases_ptr = biases.map_or(ptr::null_mut(), MetalBuffer::as_ptr);
645 let ptr = unsafe {
646 ffi::mps_cnn_convolution_weights_and_biases_state_new_with_offsets(
647 weights.as_ptr(),
648 weights_offset,
649 biases_ptr,
650 biases_offset,
651 descriptor.as_ptr(),
652 )
653 };
654 if ptr.is_null() {
655 None
656 } else {
657 Some(Self { ptr })
658 }
659 }
660
661 #[must_use]
662 pub fn new_with_device(
663 device: &MetalDevice,
664 descriptor: &CnnConvolutionDescriptor,
665 ) -> Option<Self> {
666 let ptr = unsafe {
667 ffi::mps_cnn_convolution_weights_and_biases_state_new_with_device(
668 device.as_ptr(),
669 descriptor.as_ptr(),
670 )
671 };
672 if ptr.is_null() {
673 None
674 } else {
675 Some(Self { ptr })
676 }
677 }
678
679 #[must_use]
680 pub fn weights_offset(&self) -> usize {
681 unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_weights_offset(self.ptr) }
682 }
683
684 #[must_use]
685 pub fn biases_offset(&self) -> usize {
686 unsafe { ffi::mps_cnn_convolution_weights_and_biases_state_biases_offset(self.ptr) }
687 }
688}
689
690opaque_handle!(NNOptimizerDescriptor);
691impl NNOptimizerDescriptor {
692 #[must_use]
693 pub fn new(
694 learning_rate: f32,
695 gradient_rescale: f32,
696 regularization_type: usize,
697 regularization_scale: f32,
698 ) -> Option<Self> {
699 let ptr = unsafe {
700 ffi::mps_nn_optimizer_descriptor_new(
701 learning_rate,
702 gradient_rescale,
703 regularization_type,
704 regularization_scale,
705 )
706 };
707 if ptr.is_null() {
708 None
709 } else {
710 Some(Self { ptr })
711 }
712 }
713
714 #[must_use]
715 pub fn with_gradient_clipping(
716 learning_rate: f32,
717 gradient_rescale: f32,
718 apply_gradient_clipping: bool,
719 gradient_clip_max: f32,
720 gradient_clip_min: f32,
721 regularization_type: usize,
722 regularization_scale: f32,
723 ) -> Option<Self> {
724 let ptr = unsafe {
725 ffi::mps_nn_optimizer_descriptor_new_with_gradient_clipping(
726 learning_rate,
727 gradient_rescale,
728 apply_gradient_clipping,
729 gradient_clip_max,
730 gradient_clip_min,
731 regularization_type,
732 regularization_scale,
733 )
734 };
735 if ptr.is_null() {
736 None
737 } else {
738 Some(Self { ptr })
739 }
740 }
741
742 #[must_use]
743 pub fn learning_rate(&self) -> f32 {
744 unsafe { ffi::mps_nn_optimizer_descriptor_learning_rate(self.ptr) }
745 }
746
747 pub fn set_learning_rate(&self, value: f32) {
748 unsafe { ffi::mps_nn_optimizer_descriptor_set_learning_rate(self.ptr, value) };
749 }
750
751 #[must_use]
752 pub fn gradient_rescale(&self) -> f32 {
753 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_rescale(self.ptr) }
754 }
755
756 pub fn set_gradient_rescale(&self, value: f32) {
757 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_rescale(self.ptr, value) };
758 }
759
760 #[must_use]
761 pub fn apply_gradient_clipping(&self) -> bool {
762 unsafe { ffi::mps_nn_optimizer_descriptor_apply_gradient_clipping(self.ptr) }
763 }
764
765 pub fn set_apply_gradient_clipping(&self, value: bool) {
766 unsafe { ffi::mps_nn_optimizer_descriptor_set_apply_gradient_clipping(self.ptr, value) };
767 }
768
769 #[must_use]
770 pub fn gradient_clip_max(&self) -> f32 {
771 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_max(self.ptr) }
772 }
773
774 pub fn set_gradient_clip_max(&self, value: f32) {
775 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_max(self.ptr, value) };
776 }
777
778 #[must_use]
779 pub fn gradient_clip_min(&self) -> f32 {
780 unsafe { ffi::mps_nn_optimizer_descriptor_gradient_clip_min(self.ptr) }
781 }
782
783 pub fn set_gradient_clip_min(&self, value: f32) {
784 unsafe { ffi::mps_nn_optimizer_descriptor_set_gradient_clip_min(self.ptr, value) };
785 }
786
787 #[must_use]
788 pub fn regularization_scale(&self) -> f32 {
789 unsafe { ffi::mps_nn_optimizer_descriptor_regularization_scale(self.ptr) }
790 }
791
792 pub fn set_regularization_scale(&self, value: f32) {
793 unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_scale(self.ptr, value) };
794 }
795
796 #[must_use]
797 pub fn regularization_type(&self) -> usize {
798 unsafe { ffi::mps_nn_optimizer_descriptor_regularization_type(self.ptr) }
799 }
800
801 pub fn set_regularization_type(&self, value: usize) {
802 unsafe { ffi::mps_nn_optimizer_descriptor_set_regularization_type(self.ptr, value) };
803 }
804}
805
// Generic optimizer handle; the concrete optimizers below (SGD, RMSProp,
// Adam) can be re-retained as this type via their `as_optimizer` helpers.
opaque_handle!(NNOptimizer);
impl_optimizer_common!(NNOptimizer);
808
809opaque_handle!(NNOptimizerStochasticGradientDescent);
810impl_optimizer_common!(NNOptimizerStochasticGradientDescent);
811impl NNOptimizerStochasticGradientDescent {
812 #[must_use]
813 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
814 let ptr = unsafe { ffi::mps_nn_optimizer_sgd_new(device.as_ptr(), learning_rate) };
815 if ptr.is_null() {
816 None
817 } else {
818 Some(Self { ptr })
819 }
820 }
821
822 #[must_use]
823 pub fn new_with_options(
824 device: &MetalDevice,
825 momentum_scale: f32,
826 use_nesterov_momentum: bool,
827 optimizer_descriptor: &NNOptimizerDescriptor,
828 ) -> Option<Self> {
829 let ptr = unsafe {
830 ffi::mps_nn_optimizer_sgd_new_with_options(
831 device.as_ptr(),
832 momentum_scale,
833 use_nesterov_momentum,
834 optimizer_descriptor.as_ptr(),
835 )
836 };
837 if ptr.is_null() {
838 None
839 } else {
840 Some(Self { ptr })
841 }
842 }
843
844 #[must_use]
845 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
846 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
847 }
848
849 #[must_use]
850 pub fn momentum_scale(&self) -> f32 {
851 unsafe { ffi::mps_nn_optimizer_sgd_momentum_scale(self.ptr) }
852 }
853
854 #[must_use]
855 pub fn use_nesterov_momentum(&self) -> bool {
856 unsafe { ffi::mps_nn_optimizer_sgd_use_nesterov_momentum(self.ptr) }
857 }
858
859 pub fn encode_vector(
860 &self,
861 command_buffer: &CommandBuffer,
862 input_gradient_vector: &Vector,
863 input_values_vector: &Vector,
864 input_momentum_vector: Option<&Vector>,
865 result_values_vector: &Vector,
866 ) {
867 let input_momentum_ptr = input_momentum_vector.map_or(ptr::null_mut(), Vector::as_ptr);
868 unsafe {
869 ffi::mps_nn_optimizer_sgd_encode_vector(
870 self.ptr,
871 command_buffer.as_ptr(),
872 input_gradient_vector.as_ptr(),
873 input_values_vector.as_ptr(),
874 input_momentum_ptr,
875 result_values_vector.as_ptr(),
876 );
877 };
878 }
879
880 pub fn encode_matrix(
881 &self,
882 command_buffer: &CommandBuffer,
883 input_gradient_matrix: &Matrix,
884 input_values_matrix: &Matrix,
885 input_momentum_matrix: Option<&Matrix>,
886 result_values_matrix: &Matrix,
887 ) {
888 let input_momentum_ptr = input_momentum_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
889 unsafe {
890 ffi::mps_nn_optimizer_sgd_encode_matrix(
891 self.ptr,
892 command_buffer.as_ptr(),
893 input_gradient_matrix.as_ptr(),
894 input_values_matrix.as_ptr(),
895 input_momentum_ptr,
896 result_values_matrix.as_ptr(),
897 );
898 };
899 }
900}
901
902opaque_handle!(NNOptimizerRmsProp);
903impl_optimizer_common!(NNOptimizerRmsProp);
904impl NNOptimizerRmsProp {
905 #[must_use]
906 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
907 let ptr = unsafe { ffi::mps_nn_optimizer_rmsprop_new(device.as_ptr(), learning_rate) };
908 if ptr.is_null() {
909 None
910 } else {
911 Some(Self { ptr })
912 }
913 }
914
915 #[must_use]
916 pub fn new_with_options(
917 device: &MetalDevice,
918 decay: f64,
919 epsilon: f32,
920 optimizer_descriptor: &NNOptimizerDescriptor,
921 ) -> Option<Self> {
922 let ptr = unsafe {
923 ffi::mps_nn_optimizer_rmsprop_new_with_options(
924 device.as_ptr(),
925 decay,
926 epsilon,
927 optimizer_descriptor.as_ptr(),
928 )
929 };
930 if ptr.is_null() {
931 None
932 } else {
933 Some(Self { ptr })
934 }
935 }
936
937 #[must_use]
938 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
939 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
940 }
941
942 #[must_use]
943 pub fn decay(&self) -> f64 {
944 unsafe { ffi::mps_nn_optimizer_rmsprop_decay(self.ptr) }
945 }
946
947 #[must_use]
948 pub fn epsilon(&self) -> f32 {
949 unsafe { ffi::mps_nn_optimizer_rmsprop_epsilon(self.ptr) }
950 }
951
952 pub fn encode_vector(
953 &self,
954 command_buffer: &CommandBuffer,
955 input_gradient_vector: &Vector,
956 input_values_vector: &Vector,
957 input_sum_of_squares_vector: &Vector,
958 result_values_vector: &Vector,
959 ) {
960 unsafe {
961 ffi::mps_nn_optimizer_rmsprop_encode_vector(
962 self.ptr,
963 command_buffer.as_ptr(),
964 input_gradient_vector.as_ptr(),
965 input_values_vector.as_ptr(),
966 input_sum_of_squares_vector.as_ptr(),
967 result_values_vector.as_ptr(),
968 );
969 };
970 }
971
972 pub fn encode_matrix(
973 &self,
974 command_buffer: &CommandBuffer,
975 input_gradient_matrix: &Matrix,
976 input_values_matrix: &Matrix,
977 input_sum_of_squares_matrix: &Matrix,
978 result_values_matrix: &Matrix,
979 ) {
980 unsafe {
981 ffi::mps_nn_optimizer_rmsprop_encode_matrix(
982 self.ptr,
983 command_buffer.as_ptr(),
984 input_gradient_matrix.as_ptr(),
985 input_values_matrix.as_ptr(),
986 input_sum_of_squares_matrix.as_ptr(),
987 result_values_matrix.as_ptr(),
988 );
989 };
990 }
991}
992
993opaque_handle!(NNOptimizerAdam);
994impl_optimizer_common!(NNOptimizerAdam);
995impl NNOptimizerAdam {
996 #[must_use]
997 pub fn new(device: &MetalDevice, learning_rate: f32) -> Option<Self> {
998 let ptr = unsafe { ffi::mps_nn_optimizer_adam_new(device.as_ptr(), learning_rate) };
999 if ptr.is_null() {
1000 None
1001 } else {
1002 Some(Self { ptr })
1003 }
1004 }
1005
1006 #[must_use]
1007 pub fn new_with_options(
1008 device: &MetalDevice,
1009 beta1: f64,
1010 beta2: f64,
1011 epsilon: f32,
1012 time_step: usize,
1013 optimizer_descriptor: &NNOptimizerDescriptor,
1014 ) -> Option<Self> {
1015 let ptr = unsafe {
1016 ffi::mps_nn_optimizer_adam_new_with_options(
1017 device.as_ptr(),
1018 beta1,
1019 beta2,
1020 epsilon,
1021 time_step,
1022 optimizer_descriptor.as_ptr(),
1023 )
1024 };
1025 if ptr.is_null() {
1026 None
1027 } else {
1028 Some(Self { ptr })
1029 }
1030 }
1031
1032 #[must_use]
1033 pub fn as_optimizer(&self) -> Option<NNOptimizer> {
1034 retained_handle(self.ptr).map(|ptr| NNOptimizer { ptr })
1035 }
1036
1037 #[must_use]
1038 pub fn beta1(&self) -> f64 {
1039 unsafe { ffi::mps_nn_optimizer_adam_beta1(self.ptr) }
1040 }
1041
1042 #[must_use]
1043 pub fn beta2(&self) -> f64 {
1044 unsafe { ffi::mps_nn_optimizer_adam_beta2(self.ptr) }
1045 }
1046
1047 #[must_use]
1048 pub fn epsilon(&self) -> f32 {
1049 unsafe { ffi::mps_nn_optimizer_adam_epsilon(self.ptr) }
1050 }
1051
1052 #[must_use]
1053 pub fn time_step(&self) -> usize {
1054 unsafe { ffi::mps_nn_optimizer_adam_time_step(self.ptr) }
1055 }
1056
1057 pub fn set_time_step(&self, value: usize) {
1058 unsafe { ffi::mps_nn_optimizer_adam_set_time_step(self.ptr, value) };
1059 }
1060
1061 pub fn encode_vector(
1062 &self,
1063 command_buffer: &CommandBuffer,
1064 input_gradient_vector: &Vector,
1065 input_values_vector: &Vector,
1066 input_momentum_vector: &Vector,
1067 input_velocity_vector: &Vector,
1068 result_values_vector: &Vector,
1069 ) {
1070 unsafe {
1071 ffi::mps_nn_optimizer_adam_encode_vector(
1072 self.ptr,
1073 command_buffer.as_ptr(),
1074 input_gradient_vector.as_ptr(),
1075 input_values_vector.as_ptr(),
1076 input_momentum_vector.as_ptr(),
1077 input_velocity_vector.as_ptr(),
1078 result_values_vector.as_ptr(),
1079 );
1080 };
1081 }
1082
1083 pub fn encode_matrix(
1084 &self,
1085 command_buffer: &CommandBuffer,
1086 input_gradient_matrix: &Matrix,
1087 input_values_matrix: &Matrix,
1088 input_momentum_matrix: &Matrix,
1089 input_velocity_matrix: &Matrix,
1090 result_values_matrix: &Matrix,
1091 ) {
1092 unsafe {
1093 ffi::mps_nn_optimizer_adam_encode_matrix(
1094 self.ptr,
1095 command_buffer.as_ptr(),
1096 input_gradient_matrix.as_ptr(),
1097 input_values_matrix.as_ptr(),
1098 input_momentum_matrix.as_ptr(),
1099 input_velocity_matrix.as_ptr(),
1100 result_values_matrix.as_ptr(),
1101 );
1102 };
1103 }
1104
1105 #[allow(clippy::too_many_arguments)]
1106 pub fn encode_amsgrad_vector(
1107 &self,
1108 command_buffer: &CommandBuffer,
1109 input_gradient_vector: &Vector,
1110 input_values_vector: &Vector,
1111 input_momentum_vector: &Vector,
1112 input_velocity_vector: &Vector,
1113 maximum_velocity_vector: Option<&Vector>,
1114 result_values_vector: &Vector,
1115 ) {
1116 let maximum_velocity_ptr = maximum_velocity_vector.map_or(ptr::null_mut(), Vector::as_ptr);
1117 unsafe {
1118 ffi::mps_nn_optimizer_adam_encode_amsgrad_vector(
1119 self.ptr,
1120 command_buffer.as_ptr(),
1121 input_gradient_vector.as_ptr(),
1122 input_values_vector.as_ptr(),
1123 input_momentum_vector.as_ptr(),
1124 input_velocity_vector.as_ptr(),
1125 maximum_velocity_ptr,
1126 result_values_vector.as_ptr(),
1127 );
1128 };
1129 }
1130
1131 #[allow(clippy::too_many_arguments)]
1132 pub fn encode_amsgrad_matrix(
1133 &self,
1134 command_buffer: &CommandBuffer,
1135 input_gradient_matrix: &Matrix,
1136 input_values_matrix: &Matrix,
1137 input_momentum_matrix: &Matrix,
1138 input_velocity_matrix: &Matrix,
1139 maximum_velocity_matrix: Option<&Matrix>,
1140 result_values_matrix: &Matrix,
1141 ) {
1142 let maximum_velocity_ptr = maximum_velocity_matrix.map_or(ptr::null_mut(), Matrix::as_ptr);
1143 unsafe {
1144 ffi::mps_nn_optimizer_adam_encode_amsgrad_matrix(
1145 self.ptr,
1146 command_buffer.as_ptr(),
1147 input_gradient_matrix.as_ptr(),
1148 input_values_matrix.as_ptr(),
1149 input_momentum_matrix.as_ptr(),
1150 input_velocity_matrix.as_ptr(),
1151 maximum_velocity_ptr,
1152 result_values_matrix.as_ptr(),
1153 );
1154 };
1155 }
1156}
1157
// Generic RNN descriptor handle; the specialized descriptors (single-gate,
// GRU, LSTM) can be re-retained as this type via their `as_descriptor`
// helpers.
opaque_handle!(RnnDescriptor);
impl_rnn_descriptor_common!(RnnDescriptor);
1160
1161opaque_handle!(GruDescriptor);
1162impl_rnn_descriptor_common!(GruDescriptor);
1163impl GruDescriptor {
1164 #[must_use]
1165 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1166 let ptr = unsafe { ffi::mps_gru_descriptor_new(input_feature_channels, output_feature_channels) };
1167 if ptr.is_null() {
1168 None
1169 } else {
1170 Some(Self { ptr })
1171 }
1172 }
1173
1174 #[must_use]
1175 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1176 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1177 }
1178
1179 #[must_use]
1180 pub fn gate_pnorm_value(&self) -> f32 {
1181 unsafe { ffi::mps_gru_descriptor_gate_pnorm_value(self.ptr) }
1182 }
1183
1184 pub fn set_gate_pnorm_value(&self, value: f32) {
1185 unsafe { ffi::mps_gru_descriptor_set_gate_pnorm_value(self.ptr, value) };
1186 }
1187
1188 #[must_use]
1189 pub fn flip_output_gates(&self) -> bool {
1190 unsafe { ffi::mps_gru_descriptor_flip_output_gates(self.ptr) }
1191 }
1192
1193 pub fn set_flip_output_gates(&self, value: bool) {
1194 unsafe { ffi::mps_gru_descriptor_set_flip_output_gates(self.ptr, value) };
1195 }
1196}
1197
1198opaque_handle!(LstmDescriptor);
1199impl_rnn_descriptor_common!(LstmDescriptor);
1200impl LstmDescriptor {
1201 #[must_use]
1202 pub fn new(input_feature_channels: usize, output_feature_channels: usize) -> Option<Self> {
1203 let ptr = unsafe { ffi::mps_lstm_descriptor_new(input_feature_channels, output_feature_channels) };
1204 if ptr.is_null() {
1205 None
1206 } else {
1207 Some(Self { ptr })
1208 }
1209 }
1210
1211 #[must_use]
1212 pub fn as_descriptor(&self) -> Option<RnnDescriptor> {
1213 retained_handle(self.ptr).map(|ptr| RnnDescriptor { ptr })
1214 }
1215
1216 #[must_use]
1217 pub fn memory_weights_are_diagonal(&self) -> bool {
1218 unsafe { ffi::mps_lstm_descriptor_memory_weights_are_diagonal(self.ptr) }
1219 }
1220
1221 pub fn set_memory_weights_are_diagonal(&self, value: bool) {
1222 unsafe { ffi::mps_lstm_descriptor_set_memory_weights_are_diagonal(self.ptr, value) };
1223 }
1224
1225 #[must_use]
1226 pub fn cell_to_output_neuron_type(&self) -> usize {
1227 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_type(self.ptr) }
1228 }
1229
1230 pub fn set_cell_to_output_neuron_type(&self, value: usize) {
1231 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_type(self.ptr, value) };
1232 }
1233
1234 #[must_use]
1235 pub fn cell_to_output_neuron_param_a(&self) -> f32 {
1236 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_a(self.ptr) }
1237 }
1238
1239 pub fn set_cell_to_output_neuron_param_a(&self, value: f32) {
1240 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_a(self.ptr, value) };
1241 }
1242
1243 #[must_use]
1244 pub fn cell_to_output_neuron_param_b(&self) -> f32 {
1245 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_b(self.ptr) }
1246 }
1247
1248 pub fn set_cell_to_output_neuron_param_b(&self, value: f32) {
1249 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_b(self.ptr, value) };
1250 }
1251
1252 #[must_use]
1253 pub fn cell_to_output_neuron_param_c(&self) -> f32 {
1254 unsafe { ffi::mps_lstm_descriptor_cell_to_output_neuron_param_c(self.ptr) }
1255 }
1256
1257 pub fn set_cell_to_output_neuron_param_c(&self, value: f32) {
1258 unsafe { ffi::mps_lstm_descriptor_set_cell_to_output_neuron_param_c(self.ptr, value) };
1259 }
1260}
1261
1262opaque_handle!(RnnRecurrentImageState);
1263impl RnnRecurrentImageState {
1264 #[must_use]
1265 pub fn recurrent_output_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1266 let ptr = unsafe { ffi::mps_rnn_recurrent_image_state_recurrent_output_image(self.ptr, layer_index) };
1267 if ptr.is_null() {
1268 None
1269 } else {
1270 Some(unsafe { Image::from_raw(ptr) })
1271 }
1272 }
1273
1274 #[must_use]
1275 pub fn memory_cell_image_for_layer_index(&self, layer_index: usize) -> Option<Image> {
1276 let ptr = unsafe { ffi::mps_rnn_recurrent_image_state_memory_cell_image(self.ptr, layer_index) };
1277 if ptr.is_null() {
1278 None
1279 } else {
1280 Some(unsafe { Image::from_raw(ptr) })
1281 }
1282 }
1283}
1284
1285opaque_handle!(RnnImageInferenceLayer);
1286impl RnnImageInferenceLayer {
1287 #[must_use]
1288 pub fn new(device: &MetalDevice, descriptor: &RnnDescriptor) -> Option<Self> {
1289 let ptr = unsafe { ffi::mps_rnn_image_inference_layer_new(device.as_ptr(), descriptor.as_ptr()) };
1290 if ptr.is_null() {
1291 None
1292 } else {
1293 Some(Self { ptr })
1294 }
1295 }
1296
1297 #[must_use]
1298 pub fn new_stack(device: &MetalDevice, descriptors: &[&RnnDescriptor]) -> Option<Self> {
1299 let handles: Vec<_> = descriptors.iter().map(|descriptor| descriptor.as_ptr()).collect();
1300 let handles_ptr = if handles.is_empty() {
1301 ptr::null()
1302 } else {
1303 handles.as_ptr()
1304 };
1305 let ptr = unsafe {
1306 ffi::mps_rnn_image_inference_layer_new_stack(
1307 device.as_ptr(),
1308 descriptors.len(),
1309 handles_ptr,
1310 )
1311 };
1312 if ptr.is_null() {
1313 None
1314 } else {
1315 Some(Self { ptr })
1316 }
1317 }
1318
1319 #[must_use]
1320 pub fn input_feature_channels(&self) -> usize {
1321 unsafe { ffi::mps_rnn_image_inference_layer_input_feature_channels(self.ptr) }
1322 }
1323
1324 #[must_use]
1325 pub fn output_feature_channels(&self) -> usize {
1326 unsafe { ffi::mps_rnn_image_inference_layer_output_feature_channels(self.ptr) }
1327 }
1328
1329 #[must_use]
1330 pub fn number_of_layers(&self) -> usize {
1331 unsafe { ffi::mps_rnn_image_inference_layer_number_of_layers(self.ptr) }
1332 }
1333
1334 #[must_use]
1335 pub fn recurrent_output_is_temporary(&self) -> bool {
1336 unsafe { ffi::mps_rnn_image_inference_layer_recurrent_output_is_temporary(self.ptr) }
1337 }
1338
1339 pub fn set_recurrent_output_is_temporary(&self, value: bool) {
1340 unsafe { ffi::mps_rnn_image_inference_layer_set_recurrent_output_is_temporary(self.ptr, value) };
1341 }
1342
1343 #[must_use]
1344 pub fn store_all_intermediate_states(&self) -> bool {
1345 unsafe { ffi::mps_rnn_image_inference_layer_store_all_intermediate_states(self.ptr) }
1346 }
1347
1348 pub fn set_store_all_intermediate_states(&self, value: bool) {
1349 unsafe { ffi::mps_rnn_image_inference_layer_set_store_all_intermediate_states(self.ptr, value) };
1350 }
1351
1352 #[must_use]
1353 pub fn bidirectional_combine_mode(&self) -> usize {
1354 unsafe { ffi::mps_rnn_image_inference_layer_bidirectional_combine_mode(self.ptr) }
1355 }
1356
1357 pub fn set_bidirectional_combine_mode(&self, value: usize) {
1358 unsafe { ffi::mps_rnn_image_inference_layer_set_bidirectional_combine_mode(self.ptr, value) };
1359 }
1360
1361 #[must_use]
1362 pub fn encode_sequence(
1363 &self,
1364 command_buffer: &CommandBuffer,
1365 source_images: &[&Image],
1366 destination_images: &[&Image],
1367 recurrent_input_state: Option<&RnnRecurrentImageState>,
1368 ) -> Option<RnnRecurrentImageState> {
1369 if source_images.len() != destination_images.len() {
1370 return None;
1371 }
1372 let source_handles: Vec<_> = source_images.iter().map(|image| image.as_ptr()).collect();
1373 let destination_handles: Vec<_> = destination_images.iter().map(|image| image.as_ptr()).collect();
1374 let source_ptr = if source_handles.is_empty() {
1375 ptr::null()
1376 } else {
1377 source_handles.as_ptr()
1378 };
1379 let destination_ptr = if destination_handles.is_empty() {
1380 ptr::null()
1381 } else {
1382 destination_handles.as_ptr()
1383 };
1384 let recurrent_input_ptr = recurrent_input_state.map_or(ptr::null_mut(), RnnRecurrentImageState::as_ptr);
1385 let ptr = unsafe {
1386 ffi::mps_rnn_image_inference_layer_encode_sequence(
1387 self.ptr,
1388 command_buffer.as_ptr(),
1389 source_images.len(),
1390 source_ptr,
1391 destination_ptr,
1392 recurrent_input_ptr,
1393 )
1394 };
1395 if ptr.is_null() {
1396 None
1397 } else {
1398 Some(RnnRecurrentImageState { ptr })
1399 }
1400 }
1401}