objc2_metal_performance_shaders/generated/MPSNeuralNetwork/
MPSRNNLayer.rs

1//! This file has been automatically generated by `objc2`'s `header-translator`.
2//! DO NOT EDIT
3use core::ffi::*;
4use core::ptr::NonNull;
5use objc2::__framework_prelude::*;
6use objc2_foundation::*;
7use objc2_metal::*;
8
9use crate::*;
10
11/// Direction in which an RNN layer processes an input sequence. See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnsequencedirection?language=objc)
12// NS_ENUM: represented as a transparent newtype over `NSUInteger`; the named values are associated constants.
13#[repr(transparent)]
14#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
15pub struct MPSRNNSequenceDirection(pub NSUInteger);
16impl MPSRNNSequenceDirection {
17    /// Forward: the input sequence is processed from index zero to array length minus one.
18    #[doc(alias = "MPSRNNSequenceDirectionForward")]
19    pub const Forward: Self = Self(0);
20    /// Backward: the input sequence is processed from index array length minus one down to zero.
21    #[doc(alias = "MPSRNNSequenceDirectionBackward")]
22    pub const Backward: Self = Self(1);
23}
24
// SAFETY: `MPSRNNSequenceDirection` is a `#[repr(transparent)]` wrapper around `NSUInteger`,
// so it reuses the encoding of the wrapped integer.
25unsafe impl Encode for MPSRNNSequenceDirection {
26    const ENCODING: Encoding = NSUInteger::ENCODING;
27}
28
// A reference to this type is encoded as a pointer to the type's own `ENCODING`.
29unsafe impl RefEncode for MPSRNNSequenceDirection {
30    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
31}
32
33/// Selects how the two sequences are combined: kept separate, summed, or concatenated. See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnbidirectionalcombinemode?language=objc)
34// NS_ENUM: represented as a transparent newtype over `NSUInteger`; the named values are associated constants.
35#[repr(transparent)]
36#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
37pub struct MPSRNNBidirectionalCombineMode(pub NSUInteger);
38impl MPSRNNBidirectionalCombineMode {
39    /// The two sequences are kept separate.
40    #[doc(alias = "MPSRNNBidirectionalCombineModeNone")]
41    pub const None: Self = Self(0);
42    /// The two sequences are summed together to form a single output.
43    #[doc(alias = "MPSRNNBidirectionalCombineModeAdd")]
44    pub const Add: Self = Self(1);
45    /// The two sequences are concatenated together along the feature channels to form a single output.
46    #[doc(alias = "MPSRNNBidirectionalCombineModeConcatenate")]
47    pub const Concatenate: Self = Self(2);
48}
49
// SAFETY: `MPSRNNBidirectionalCombineMode` is a `#[repr(transparent)]` wrapper around `NSUInteger`,
// so it reuses the encoding of the wrapped integer.
50unsafe impl Encode for MPSRNNBidirectionalCombineMode {
51    const ENCODING: Encoding = NSUInteger::ENCODING;
52}
53
// A reference to this type is encoded as a pointer to the type's own `ENCODING`.
54unsafe impl RefEncode for MPSRNNBidirectionalCombineMode {
55    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
56}
57
58extern_class!(
59    /// Dependencies: This depends on Metal.framework.
60    ///
61    /// The MPSRNNDescriptor specifies a recurrent neural network (RNN) block/layer descriptor.
62    ///
63    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnndescriptor?language=objc)
64    #[unsafe(super(NSObject))]
65    #[derive(Debug, PartialEq, Eq, Hash)]
66    pub struct MPSRNNDescriptor;
67);
68
// Declares that `MPSRNNDescriptor` conforms to `NSObjectProtocol`.
69extern_conformance!(
70    unsafe impl NSObjectProtocol for MPSRNNDescriptor {}
71);
72
73impl MPSRNNDescriptor {
74    extern_methods!(
75        /// The number of feature channels per pixel in the input image or number of rows in the input matrix.
76        #[unsafe(method(inputFeatureChannels))]
77        #[unsafe(method_family = none)]
78        pub unsafe fn inputFeatureChannels(&self) -> NSUInteger;
79
80        /// Setter for [`inputFeatureChannels`][Self::inputFeatureChannels].
81        #[unsafe(method(setInputFeatureChannels:))]
82        #[unsafe(method_family = none)]
83        pub unsafe fn setInputFeatureChannels(&self, input_feature_channels: NSUInteger);
84
85        /// The number of feature channels per pixel in the destination image or number of rows in the destination matrix.
86        #[unsafe(method(outputFeatureChannels))]
87        #[unsafe(method_family = none)]
88        pub unsafe fn outputFeatureChannels(&self) -> NSUInteger;
89
90        /// Setter for [`outputFeatureChannels`][Self::outputFeatureChannels].
91        #[unsafe(method(setOutputFeatureChannels:))]
92        #[unsafe(method_family = none)]
93        pub unsafe fn setOutputFeatureChannels(&self, output_feature_channels: NSUInteger);
94
95        /// If YES, then use the identity transformation for all weights (W, Wr, Wi, Wf, Wo, Wc) affecting input x_j in this layer,
96        /// even if said weights are specified as nil.
97        /// For example 'W_ij * x_j' is replaced by 'x_j' in the formulae defined in
98        /// MPSRNNSingleGateDescriptor. Defaults to NO.
99        #[unsafe(method(useLayerInputUnitTransformMode))]
100        #[unsafe(method_family = none)]
101        pub unsafe fn useLayerInputUnitTransformMode(&self) -> bool;
102
103        /// Setter for [`useLayerInputUnitTransformMode`][Self::useLayerInputUnitTransformMode].
104        #[unsafe(method(setUseLayerInputUnitTransformMode:))]
105        #[unsafe(method_family = none)]
106        pub unsafe fn setUseLayerInputUnitTransformMode(
107            &self,
108            use_layer_input_unit_transform_mode: bool,
109        );
110
111        /// If YES, then
112        /// MPSRNNMatrixInferenceLayer uses 32-bit floating point numbers internally for weights when
113        /// computing matrix transformations. If NO, then 16-bit, half precision floating point numbers are used.
114        /// Currently
115        /// MPSRNNImageInferenceLayer ignores this property and the convolution operations always
116        /// convert FP32 weights into FP16 for better performance.
117        /// Defaults to NO.
118        #[unsafe(method(useFloat32Weights))]
119        #[unsafe(method_family = none)]
120        pub unsafe fn useFloat32Weights(&self) -> bool;
121
122        /// Setter for [`useFloat32Weights`][Self::useFloat32Weights].
123        #[unsafe(method(setUseFloat32Weights:))]
124        #[unsafe(method_family = none)]
125        pub unsafe fn setUseFloat32Weights(&self, use_float32_weights: bool);
126
127        /// When the layer specified with this descriptor is used to process a sequence of inputs
128        /// by calling `encodeBidirectionalSequenceToCommandBuffer`, this parameter defines
129        /// in which direction the sequence is processed.
130        ///
131        /// The operation of the layer is:
132        /// (yt, ht, ct) = f(xt,ht-1,ct-1) for MPSRNNSequenceDirectionForward
133        /// and
134        /// (yt, ht, ct) = f(xt,ht+1,ct+1) for MPSRNNSequenceDirectionBackward, where
135        /// xt is the output of the previous layer that encodes in the same direction as this layer
136        /// (or the input image or matrix if this is the first layer in the stack with this direction).
137        ///
138        /// See also:
139        /// - `MPSRNNImageInferenceLayer`
140        /// - `MPSRNNMatrixInferenceLayer`
141        #[unsafe(method(layerSequenceDirection))]
142        #[unsafe(method_family = none)]
143        pub unsafe fn layerSequenceDirection(&self) -> MPSRNNSequenceDirection;
144
145        /// Setter for [`layerSequenceDirection`][Self::layerSequenceDirection].
146        #[unsafe(method(setLayerSequenceDirection:))]
147        #[unsafe(method_family = none)]
148        pub unsafe fn setLayerSequenceDirection(
149            &self,
150            layer_sequence_direction: MPSRNNSequenceDirection,
151        );
152    );
153}
154
155/// Methods declared on superclass `NSObject`: the `init` and `new` initializers.
156impl MPSRNNDescriptor {
157    extern_methods!(
158        #[unsafe(method(init))]
159        #[unsafe(method_family = init)]
160        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
161
162        #[unsafe(method(new))]
163        #[unsafe(method_family = new)]
164        pub unsafe fn new() -> Retained<Self>;
165    );
166}
167
168extern_class!(
169    /// Dependencies: This depends on Metal.framework.
170    ///
171    /// The MPSRNNSingleGateDescriptor specifies a simple recurrent block/layer descriptor.
172    /// The RNN layer initialized with a MPSRNNSingleGateDescriptor transforms the input data (image or matrix),
173    /// and previous output with a set of filters, each producing one feature map in the new output data.
174    /// The user may provide the RNN unit a single input or a sequence of inputs.
175    ///
176    /// Description of operation:
177    ///
178    /// Let x_j be the input data (at time index t of sequence,
179    /// j index containing quadruplet: batch index, x,y and feature index (x=y=0 for matrices)).
180    /// Let h0_j be the recurrent input (previous output) data from previous time step (at time index t-1 of sequence).
181    /// Let h1_i be the output data produced at this time step.
182    ///
183    /// Let W_ij, U_ij be the weights for input and recurrent input data respectively.
184    /// Let b_i be a bias term.
185    ///
186    /// Let gi(x) be a neuron activation function.
187    ///
188    /// Then the new output image h1_i data is computed as follows:
189    ///
190    /// h1_i = gi( W_ij * x_j + U_ij * h0_j  + b_i )
191    ///
192    /// The '*' stands for convolution (see
193    /// MPSRNNImageInferenceLayer) or matrix-vector/matrix multiplication
194    /// (see
195    /// MPSRNNMatrixInferenceLayer). Summation is over index j (except for the batch index), but there is no summation over
196    /// repeated index i - the output index.
197    /// Note that for validity all intermediate images have to be of same size and the U matrix has to be square
198    /// (i.e. outputFeatureChannels == inputFeatureChannels in those). Also the bias terms are scalars w.r.t. spatial dimensions.
199    ///
200    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnsinglegatedescriptor?language=objc)
201    #[unsafe(super(MPSRNNDescriptor, NSObject))]
202    #[derive(Debug, PartialEq, Eq, Hash)]
203    pub struct MPSRNNSingleGateDescriptor;
204);
205
// Declares that `MPSRNNSingleGateDescriptor` conforms to `NSObjectProtocol`.
206extern_conformance!(
207    unsafe impl NSObjectProtocol for MPSRNNSingleGateDescriptor {}
208);
209
210impl MPSRNNSingleGateDescriptor {
211    extern_methods!(
212        #[cfg(feature = "MPSCNNConvolution")]
213        /// Contains weights 'W_ij', bias 'b_i' and neuron 'gi' from the simple RNN layer formula.
214        /// If nil, then zero weights, zero bias and no neuron (identity mapping) are assumed. Defaults to nil.
215        #[unsafe(method(inputWeights))]
216        #[unsafe(method_family = none)]
217        pub unsafe fn inputWeights(
218            &self,
219        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
220
221        #[cfg(feature = "MPSCNNConvolution")]
222        /// Setter for [`inputWeights`][Self::inputWeights].
223        #[unsafe(method(setInputWeights:))]
224        #[unsafe(method_family = none)]
225        pub unsafe fn setInputWeights(
226            &self,
227            input_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
228        );
229
230        #[cfg(feature = "MPSCNNConvolution")]
231        /// Contains weights 'U_ij' from the simple RNN layer formula.
232        /// If nil, then zero weights are assumed. Defaults to nil.
233        #[unsafe(method(recurrentWeights))]
234        #[unsafe(method_family = none)]
235        pub unsafe fn recurrentWeights(
236            &self,
237        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
238
239        #[cfg(feature = "MPSCNNConvolution")]
240        /// Setter for [`recurrentWeights`][Self::recurrentWeights].
241        #[unsafe(method(setRecurrentWeights:))]
242        #[unsafe(method_family = none)]
243        pub unsafe fn setRecurrentWeights(
244            &self,
245            recurrent_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
246        );
247
248        /// Creates a MPSRNNSingleGateDescriptor.
249        ///
250        /// Parameter `inputFeatureChannels`: The number of feature channels in the input image/matrix. Must be >= 1.
251        ///
252        /// Parameter `outputFeatureChannels`: The number of feature channels in the output image/matrix. Must be >= 1.
253        ///
254        /// Returns: A valid MPSRNNSingleGateDescriptor object or nil, if failure.
        // NOTE(review): the doc above mentions nil, but this binding returns a non-optional
        // `Retained<Self>` — confirm that nil is never actually returned.
255        #[unsafe(method(createRNNSingleGateDescriptorWithInputFeatureChannels:outputFeatureChannels:))]
256        #[unsafe(method_family = none)]
257        pub unsafe fn createRNNSingleGateDescriptorWithInputFeatureChannels_outputFeatureChannels(
258            input_feature_channels: NSUInteger,
259            output_feature_channels: NSUInteger,
260        ) -> Retained<Self>;
261    );
262}
263
264/// Methods declared on superclass `NSObject`: the `init` and `new` initializers.
265impl MPSRNNSingleGateDescriptor {
266    extern_methods!(
267        #[unsafe(method(init))]
268        #[unsafe(method_family = init)]
269        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
270
271        #[unsafe(method(new))]
272        #[unsafe(method_family = new)]
273        pub unsafe fn new() -> Retained<Self>;
274    );
275}
276
277extern_class!(
278    /// Dependencies: This depends on Metal.framework.
279    ///
280    /// The MPSGRUDescriptor specifies a GRU (Gated Recurrent Unit) block/layer descriptor.
281    /// The RNN layer initialized with a MPSGRUDescriptor transforms the input data (image or matrix),
282    /// and previous output with a set of filters, each producing one feature map in
283    /// the output data according to the Gated unit formulae detailed below.
284    /// The user may provide the GRU unit a single input or a sequence of inputs. The layer also supports
285    /// p-norm gating (Detailed in: https://arxiv.org/abs/1608.03639 ).
286    ///
287    /// Description of operation:
288    ///
289    /// Let x_j be the input data (at time index t of sequence,
290    /// j index containing quadruplet: batch index, x,y and feature index (x=y=0 for matrices)).
291    /// Let h0_j be the recurrent input (previous output) data from previous time step (at time index t-1 of sequence).
292    /// Let h_i be the proposed new output.
293    /// Let h1_i be the output data produced at this time step.
294    ///
295    /// Let Wz_ij, Uz_ij, be the input gate weights for input and recurrent input data respectively
296    /// Let bz_i be the bias for the input gate
297    ///
298    /// Let Wr_ij, Ur_ij be the recurrent gate weights for input and recurrent input data respectively
299    /// Let br_i be the bias for the recurrent gate
300    ///
301    /// Let Wh_ij, Uh_ij, Vh_ij, be the output gate weights for input, recurrent gate and input gate respectively
302    /// Let bh_i be the bias for the output gate
303    ///
304    /// Let gz(x), gr(x), gh(x) be the neuron activation function for the input, recurrent and output gates
305    /// Let p > 0 be a scalar variable (typically p >= 1.0) that defines the p-norm gating norm value.
306    ///
307    /// Then the output of the Gated Recurrent Unit layer is computed as follows:
308    ///
309    /// z_i = gz(  Wz_ij * x_j  +  Uz_ij * h0_j  +  bz_i  )
310    /// r_i = gr(  Wr_ij * x_j  +  Ur_ij * h0_j  +  br_i  )
311    /// c_i =      Uh_ij * (r_j h0_j)  +  Vh_ij * (z_j h0_j)
312    /// h_i = gh(  Wh_ij * x_j  + c_i + bh_i  )
313    ///
314    /// h1_i = ( 1 - z_i ^ p)^(1/p) h_i + z_i h0_i
315    ///
316    /// The '*' stands for convolution (see
317    /// MPSRNNImageInferenceLayer) or matrix-vector/matrix multiplication
318    /// (see
319    /// MPSRNNMatrixInferenceLayer). Summation is over index j (except for the batch index), but there is no summation over
320    /// repeated index i - the output index.
321    /// Note that for validity all intermediate images have to be of same size and all U and V matrices have to be square
322    /// (i.e. outputFeatureChannels == inputFeatureChannels in those). Also the bias terms are scalars w.r.t. spatial dimensions.
323    /// The conventional GRU block is achieved by setting Vh = 0 (nil) and the so-called Minimal Gated Unit is achieved with Uh = 0.
324    /// (The Minimal Gated Unit is detailed in: https://arxiv.org/abs/1603.09420 and there they call z_i the value of the forget gate).
325    ///
326    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsgrudescriptor?language=objc)
327    #[unsafe(super(MPSRNNDescriptor, NSObject))]
328    #[derive(Debug, PartialEq, Eq, Hash)]
329    pub struct MPSGRUDescriptor;
330);
331
// Declares that `MPSGRUDescriptor` conforms to `NSObjectProtocol`.
332extern_conformance!(
333    unsafe impl NSObjectProtocol for MPSGRUDescriptor {}
334);
335
336impl MPSGRUDescriptor {
337    extern_methods!(
338        #[cfg(feature = "MPSCNNConvolution")]
339        /// Contains weights 'Wz_ij', bias 'bz_i' and neuron 'gz' from the GRU formula.
340        /// If nil, then zero weights, zero bias and no neuron (identity mapping) are assumed. Defaults to nil.
341        #[unsafe(method(inputGateInputWeights))]
342        #[unsafe(method_family = none)]
343        pub unsafe fn inputGateInputWeights(
344            &self,
345        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
346
347        #[cfg(feature = "MPSCNNConvolution")]
348        /// Setter for [`inputGateInputWeights`][Self::inputGateInputWeights].
349        #[unsafe(method(setInputGateInputWeights:))]
350        #[unsafe(method_family = none)]
351        pub unsafe fn setInputGateInputWeights(
352            &self,
353            input_gate_input_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
354        );
355
356        #[cfg(feature = "MPSCNNConvolution")]
357        /// Contains weights 'Uz_ij' from the GRU formula.
358        /// If nil, then zero weights are assumed. Defaults to nil.
359        #[unsafe(method(inputGateRecurrentWeights))]
360        #[unsafe(method_family = none)]
361        pub unsafe fn inputGateRecurrentWeights(
362            &self,
363        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
364
365        #[cfg(feature = "MPSCNNConvolution")]
366        /// Setter for [`inputGateRecurrentWeights`][Self::inputGateRecurrentWeights].
367        #[unsafe(method(setInputGateRecurrentWeights:))]
368        #[unsafe(method_family = none)]
369        pub unsafe fn setInputGateRecurrentWeights(
370            &self,
371            input_gate_recurrent_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
372        );
373
374        #[cfg(feature = "MPSCNNConvolution")]
375        /// Contains weights 'Wr_ij', bias 'br_i' and neuron 'gr' from the GRU formula.
376        /// If nil, then zero weights, zero bias and no neuron (identity mapping) are assumed. Defaults to nil.
377        #[unsafe(method(recurrentGateInputWeights))]
378        #[unsafe(method_family = none)]
379        pub unsafe fn recurrentGateInputWeights(
380            &self,
381        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
382
383        #[cfg(feature = "MPSCNNConvolution")]
384        /// Setter for [`recurrentGateInputWeights`][Self::recurrentGateInputWeights].
385        #[unsafe(method(setRecurrentGateInputWeights:))]
386        #[unsafe(method_family = none)]
387        pub unsafe fn setRecurrentGateInputWeights(
388            &self,
389            recurrent_gate_input_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
390        );
391
392        #[cfg(feature = "MPSCNNConvolution")]
393        /// Contains weights 'Ur_ij' from the GRU formula.
394        /// If nil, then zero weights are assumed. Defaults to nil.
395        #[unsafe(method(recurrentGateRecurrentWeights))]
396        #[unsafe(method_family = none)]
397        pub unsafe fn recurrentGateRecurrentWeights(
398            &self,
399        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
400
401        #[cfg(feature = "MPSCNNConvolution")]
402        /// Setter for [`recurrentGateRecurrentWeights`][Self::recurrentGateRecurrentWeights].
403        #[unsafe(method(setRecurrentGateRecurrentWeights:))]
404        #[unsafe(method_family = none)]
405        pub unsafe fn setRecurrentGateRecurrentWeights(
406            &self,
407            recurrent_gate_recurrent_weights: Option<
408                &ProtocolObject<dyn MPSCNNConvolutionDataSource>,
409            >,
410        );
411
412        #[cfg(feature = "MPSCNNConvolution")]
413        /// Contains weights 'Wh_ij', bias 'bh_i' and neuron 'gh' from the GRU formula.
414        /// If nil, then zero weights, zero bias and no neuron (identity mapping) are assumed. Defaults to nil.
415        #[unsafe(method(outputGateInputWeights))]
416        #[unsafe(method_family = none)]
417        pub unsafe fn outputGateInputWeights(
418            &self,
419        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
420
421        #[cfg(feature = "MPSCNNConvolution")]
422        /// Setter for [`outputGateInputWeights`][Self::outputGateInputWeights].
423        #[unsafe(method(setOutputGateInputWeights:))]
424        #[unsafe(method_family = none)]
425        pub unsafe fn setOutputGateInputWeights(
426            &self,
427            output_gate_input_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
428        );
429
430        #[cfg(feature = "MPSCNNConvolution")]
431        /// Contains weights 'Uh_ij' from the GRU formula.
432        /// If nil, then zero weights are assumed. Defaults to nil.
433        #[unsafe(method(outputGateRecurrentWeights))]
434        #[unsafe(method_family = none)]
435        pub unsafe fn outputGateRecurrentWeights(
436            &self,
437        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
438
439        #[cfg(feature = "MPSCNNConvolution")]
440        /// Setter for [`outputGateRecurrentWeights`][Self::outputGateRecurrentWeights].
441        #[unsafe(method(setOutputGateRecurrentWeights:))]
442        #[unsafe(method_family = none)]
443        pub unsafe fn setOutputGateRecurrentWeights(
444            &self,
445            output_gate_recurrent_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
446        );
447
448        #[cfg(feature = "MPSCNNConvolution")]
449        /// Contains weights 'Vh_ij' - can be used to implement the "Minimally Gated Unit".
450        /// If nil, then zero weights are assumed. Defaults to nil.
451        #[unsafe(method(outputGateInputGateWeights))]
452        #[unsafe(method_family = none)]
453        pub unsafe fn outputGateInputGateWeights(
454            &self,
455        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
456
457        #[cfg(feature = "MPSCNNConvolution")]
458        /// Setter for [`outputGateInputGateWeights`][Self::outputGateInputGateWeights].
459        #[unsafe(method(setOutputGateInputGateWeights:))]
460        #[unsafe(method_family = none)]
461        pub unsafe fn setOutputGateInputGateWeights(
462            &self,
463            output_gate_input_gate_weights: Option<
464                &ProtocolObject<dyn MPSCNNConvolutionDataSource>,
465            >,
466        );
467
468        /// The p-norm gating norm value as specified by the GRU formulae. Defaults to 1.0f.
469        #[unsafe(method(gatePnormValue))]
470        #[unsafe(method_family = none)]
471        pub unsafe fn gatePnormValue(&self) -> c_float;
472
473        /// Setter for [`gatePnormValue`][Self::gatePnormValue].
474        #[unsafe(method(setGatePnormValue:))]
475        #[unsafe(method_family = none)]
476        pub unsafe fn setGatePnormValue(&self, gate_pnorm_value: c_float);
477
478        /// If YES, then the GRU-block output formula is changed to:
479        /// h1_i = ( 1 - z_i ^ p)^(1/p) h0_i + z_i h_i.
480        /// Defaults to NO.
481        #[unsafe(method(flipOutputGates))]
482        #[unsafe(method_family = none)]
483        pub unsafe fn flipOutputGates(&self) -> bool;
484
485        /// Setter for [`flipOutputGates`][Self::flipOutputGates].
486        #[unsafe(method(setFlipOutputGates:))]
487        #[unsafe(method_family = none)]
488        pub unsafe fn setFlipOutputGates(&self, flip_output_gates: bool);
489
490        /// Creates a GRU descriptor.
491        ///
492        /// Parameter `inputFeatureChannels`: The number of feature channels in the input image/matrix. Must be >= 1.
493        ///
494        /// Parameter `outputFeatureChannels`: The number of feature channels in the output image/matrix. Must be >= 1.
495        ///
496        /// Returns: A valid MPSGRUDescriptor object or nil, if failure.
        // NOTE(review): the doc above mentions nil, but this binding returns a non-optional
        // `Retained<Self>` — confirm that nil is never actually returned.
497        #[unsafe(method(createGRUDescriptorWithInputFeatureChannels:outputFeatureChannels:))]
498        #[unsafe(method_family = none)]
499        pub unsafe fn createGRUDescriptorWithInputFeatureChannels_outputFeatureChannels(
500            input_feature_channels: NSUInteger,
501            output_feature_channels: NSUInteger,
502        ) -> Retained<Self>;
503    );
504}
505
506/// Methods declared on superclass `NSObject`: the `init` and `new` initializers.
507impl MPSGRUDescriptor {
508    extern_methods!(
509        #[unsafe(method(init))]
510        #[unsafe(method_family = init)]
511        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
512
513        #[unsafe(method(new))]
514        #[unsafe(method_family = new)]
515        pub unsafe fn new() -> Retained<Self>;
516    );
517}
518
519extern_class!(
520    /// Dependencies: This depends on Metal.framework.
521    ///
522    /// The MPSLSTMDescriptor specifies a LSTM block/layer descriptor.
523    /// The RNN layer initialized with a MPSLSTMDescriptor transforms the input data (image or matrix),
524    /// the memory cell data and previous output with a set of filters, each producing one feature map in
525    /// the output data and memory cell, according to the LSTM formulae detailed below.
526    /// The user may provide the LSTM unit a single input or a sequence of inputs.
527    ///
528    /// Description of operation:
529    ///
530    /// Let x_j be the input data (at time index t of sequence,
531    /// j index containing quadruplet: batch index, x,y and feature index (x=y=0 for matrices)).
532    /// Let h0_j be the recurrent input (previous output) data from previous time step (at time index t-1 of sequence).
533    /// Let h1_i be the output data produced at this time step.
534    /// Let c0_j be the previous memory cell data (at time index t-1 of sequence).
535    /// Let c1_i be the new memory cell data (at time index t of sequence).
536    ///
537    /// Let Wi_ij, Ui_ij, Vi_ij, be the input gate weights for input, recurrent input and memory cell (peephole) data respectively
538    /// Let bi_i be the bias for the input gate
539    ///
540    /// Let Wf_ij, Uf_ij, Vf_ij, be the forget gate weights for input, recurrent input and memory cell data respectively
541    /// Let bf_i be the bias for the forget gate
542    ///
543    /// Let Wo_ij, Uo_ij, Vo_ij, be the output gate weights for input, recurrent input and memory cell data respectively
544    /// Let bo_i be the bias for the output gate
545    ///
546    /// Let Wc_ij, Uc_ij, Vc_ij, be the memory cell gate weights for input, recurrent input and memory cell data respectively
547    /// Let bc_i be the bias for the memory cell gate
548    ///
549    /// Let gi(x), gf(x), go(x), gc(x) be neuron activation function for the input, forget, output gate and memory cell gate
550    /// Let gh(x) be the activation function applied to result memory cell data
551    ///
552    /// Then the new memory cell data c1_i and output image h1_i are computed as follows:
553    ///
554    /// I_i = gi(  Wi_ij * x_j  +  Ui_ij * h0_j  +  Vi_ij * c0_j  + bi_i  )
555    /// F_i = gf(  Wf_ij * x_j  +  Uf_ij * h0_j  +  Vf_ij * c0_j  + bf_i  )
556    /// C_i = gc(  Wc_ij * x_j  +  Uc_ij * h0_j  +  Vc_ij * c0_j  + bc_i  )
557    ///
558    /// c1_i = F_i c0_i  +  I_i C_i
559    ///
560    /// O_i = go(  Wo_ij * x_j  +  Uo_ij * h0_j  +  Vo_ij * c1_j  + bo_i  )
561    ///
562    /// h1_i = O_i gh( c1_i )
563    ///
564    /// The '*' stands for convolution (see
565    /// MPSRNNImageInferenceLayer) or matrix-vector/matrix multiplication
566    /// (see
567    /// MPSRNNMatrixInferenceLayer). Summation is over index j (except for the batch index), but there is no summation over
568    /// repeated index i - the output index.
569    /// Note that for validity all intermediate images have to be of same size and all U and V matrices have to be square
570    /// (i.e. outputFeatureChannels == inputFeatureChannels in those). Also the bias terms are scalars w.r.t. spatial dimensions.
571    ///
572    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpslstmdescriptor?language=objc)
573    #[unsafe(super(MPSRNNDescriptor, NSObject))]
574    #[derive(Debug, PartialEq, Eq, Hash)]
575    pub struct MPSLSTMDescriptor;
576);
577
// Declares that `MPSLSTMDescriptor` conforms to `NSObjectProtocol`.
578extern_conformance!(
579    unsafe impl NSObjectProtocol for MPSLSTMDescriptor {}
580);
581
582impl MPSLSTMDescriptor {
583    extern_methods!(
584        /// If YES, then the 'peephole' weight matrices will be diagonal matrices represented as
585        /// vectors of length the number of features in memory cells, that will be multiplied pointwise
586        /// with the peephole matrix or image in order to achieve the diagonal (nonmixing) update.
587        /// Defaults to NO.
588        #[unsafe(method(memoryWeightsAreDiagonal))]
589        #[unsafe(method_family = none)]
590        pub unsafe fn memoryWeightsAreDiagonal(&self) -> bool;
591
592        /// Setter for [`memoryWeightsAreDiagonal`][Self::memoryWeightsAreDiagonal].
593        #[unsafe(method(setMemoryWeightsAreDiagonal:))]
594        #[unsafe(method_family = none)]
595        pub unsafe fn setMemoryWeightsAreDiagonal(&self, memory_weights_are_diagonal: bool);
596
597        #[cfg(feature = "MPSCNNConvolution")]
598        /// Contains weights 'Wi_ij', bias 'bi_i' and neuron 'gi' from the LSTM formula.
599        /// If nil then assumed zero weights, bias and no neuron (identity mapping). Defaults to nil.
600        #[unsafe(method(inputGateInputWeights))]
601        #[unsafe(method_family = none)]
602        pub unsafe fn inputGateInputWeights(
603            &self,
604        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
605
606        #[cfg(feature = "MPSCNNConvolution")]
607        /// Setter for [`inputGateInputWeights`][Self::inputGateInputWeights].
608        #[unsafe(method(setInputGateInputWeights:))]
609        #[unsafe(method_family = none)]
610        pub unsafe fn setInputGateInputWeights(
611            &self,
612            input_gate_input_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
613        );
614
615        #[cfg(feature = "MPSCNNConvolution")]
616        /// Contains weights 'Ui_ij' from the LSTM formula.
617        /// If nil then assumed zero weights. Defaults to nil.
618        #[unsafe(method(inputGateRecurrentWeights))]
619        #[unsafe(method_family = none)]
620        pub unsafe fn inputGateRecurrentWeights(
621            &self,
622        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
623
624        #[cfg(feature = "MPSCNNConvolution")]
625        /// Setter for [`inputGateRecurrentWeights`][Self::inputGateRecurrentWeights].
626        #[unsafe(method(setInputGateRecurrentWeights:))]
627        #[unsafe(method_family = none)]
628        pub unsafe fn setInputGateRecurrentWeights(
629            &self,
630            input_gate_recurrent_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
631        );
632
633        #[cfg(feature = "MPSCNNConvolution")]
634        /// Contains weights 'Vi_ij' - the 'peephole' weights - from the LSTM formula.
635        /// If [`memoryWeightsAreDiagonal`][Self::memoryWeightsAreDiagonal] is YES, then the number of weights used is the number of features
636        /// in the memory cell image/matrix.
637        /// If nil, then zero weights are assumed. Defaults to nil.
638        #[unsafe(method(inputGateMemoryWeights))]
639        #[unsafe(method_family = none)]
640        pub unsafe fn inputGateMemoryWeights(
641            &self,
642        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
643
644        #[cfg(feature = "MPSCNNConvolution")]
645        /// Setter for [`inputGateMemoryWeights`][Self::inputGateMemoryWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights are assumed. Only compiled when the `MPSCNNConvolution` Cargo feature
           /// is enabled.
646        #[unsafe(method(setInputGateMemoryWeights:))]
647        #[unsafe(method_family = none)]
648        pub unsafe fn setInputGateMemoryWeights(
649            &self,
650            input_gate_memory_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
651        );
652
653        #[cfg(feature = "MPSCNNConvolution")]
654        /// Contains weights 'Wf_ij', bias 'bf_i' and neuron 'gf' from the LSTM formula.
655        /// If nil then assumed zero weights, bias and no neuron (identity mapping). Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
656        #[unsafe(method(forgetGateInputWeights))]
657        #[unsafe(method_family = none)]
658        pub unsafe fn forgetGateInputWeights(
659            &self,
660        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
661
662        #[cfg(feature = "MPSCNNConvolution")]
663        /// Setter for [`forgetGateInputWeights`][Self::forgetGateInputWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights, zero bias and no neuron (identity mapping). Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
664        #[unsafe(method(setForgetGateInputWeights:))]
665        #[unsafe(method_family = none)]
666        pub unsafe fn setForgetGateInputWeights(
667            &self,
668            forget_gate_input_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
669        );
670
671        #[cfg(feature = "MPSCNNConvolution")]
672        /// Contains weights 'Uf_ij' from the LSTM formula.
673        /// If nil then assumed zero weights. Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
674        #[unsafe(method(forgetGateRecurrentWeights))]
675        #[unsafe(method_family = none)]
676        pub unsafe fn forgetGateRecurrentWeights(
677            &self,
678        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
679
680        #[cfg(feature = "MPSCNNConvolution")]
681        /// Setter for [`forgetGateRecurrentWeights`][Self::forgetGateRecurrentWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights are assumed. Only compiled when the `MPSCNNConvolution` Cargo feature
           /// is enabled.
682        #[unsafe(method(setForgetGateRecurrentWeights:))]
683        #[unsafe(method_family = none)]
684        pub unsafe fn setForgetGateRecurrentWeights(
685            &self,
686            forget_gate_recurrent_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
687        );
688
689        #[cfg(feature = "MPSCNNConvolution")]
690        /// Contains weights 'Vf_ij' - the 'peephole' weights - from the LSTM formula.
691        /// If `memoryWeightsAreDiagonal` is YES, then the number of weights used is the number of features
692        /// in the memory cell image/matrix.
693        /// If nil then assumed zero weights. Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
694        #[unsafe(method(forgetGateMemoryWeights))]
695        #[unsafe(method_family = none)]
696        pub unsafe fn forgetGateMemoryWeights(
697            &self,
698        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
699
700        #[cfg(feature = "MPSCNNConvolution")]
701        /// Setter for [`forgetGateMemoryWeights`][Self::forgetGateMemoryWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights are assumed. Only compiled when the `MPSCNNConvolution` Cargo feature
           /// is enabled.
702        #[unsafe(method(setForgetGateMemoryWeights:))]
703        #[unsafe(method_family = none)]
704        pub unsafe fn setForgetGateMemoryWeights(
705            &self,
706            forget_gate_memory_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
707        );
708
709        #[cfg(feature = "MPSCNNConvolution")]
710        /// Contains weights 'Wo_ij', bias 'bo_i' and neuron 'go' from the LSTM formula.
711        /// If nil then assumed zero weights, bias and no neuron (identity mapping). Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
712        #[unsafe(method(outputGateInputWeights))]
713        #[unsafe(method_family = none)]
714        pub unsafe fn outputGateInputWeights(
715            &self,
716        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
717
718        #[cfg(feature = "MPSCNNConvolution")]
719        /// Setter for [`outputGateInputWeights`][Self::outputGateInputWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights, zero bias and no neuron (identity mapping). Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
720        #[unsafe(method(setOutputGateInputWeights:))]
721        #[unsafe(method_family = none)]
722        pub unsafe fn setOutputGateInputWeights(
723            &self,
724            output_gate_input_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
725        );
726
727        #[cfg(feature = "MPSCNNConvolution")]
728        /// Contains weights 'Uo_ij' from the LSTM formula.
729        /// If nil then assumed zero weights. Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
730        #[unsafe(method(outputGateRecurrentWeights))]
731        #[unsafe(method_family = none)]
732        pub unsafe fn outputGateRecurrentWeights(
733            &self,
734        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
735
736        #[cfg(feature = "MPSCNNConvolution")]
737        /// Setter for [`outputGateRecurrentWeights`][Self::outputGateRecurrentWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights are assumed. Only compiled when the `MPSCNNConvolution` Cargo feature
           /// is enabled.
738        #[unsafe(method(setOutputGateRecurrentWeights:))]
739        #[unsafe(method_family = none)]
740        pub unsafe fn setOutputGateRecurrentWeights(
741            &self,
742            output_gate_recurrent_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
743        );
744
745        #[cfg(feature = "MPSCNNConvolution")]
746        /// Contains weights 'Vo_ij' - the 'peephole' weights - from the LSTM formula.
747        /// If `memoryWeightsAreDiagonal` is YES, then the number of weights used is the number of features
748        /// in the memory cell image/matrix.
749        /// If nil then assumed zero weights. Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
750        #[unsafe(method(outputGateMemoryWeights))]
751        #[unsafe(method_family = none)]
752        pub unsafe fn outputGateMemoryWeights(
753            &self,
754        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
755
756        #[cfg(feature = "MPSCNNConvolution")]
757        /// Setter for [`outputGateMemoryWeights`][Self::outputGateMemoryWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights are assumed. Only compiled when the `MPSCNNConvolution` Cargo feature
           /// is enabled.
758        #[unsafe(method(setOutputGateMemoryWeights:))]
759        #[unsafe(method_family = none)]
760        pub unsafe fn setOutputGateMemoryWeights(
761            &self,
762            output_gate_memory_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
763        );
764
765        #[cfg(feature = "MPSCNNConvolution")]
766        /// Contains weights 'Wc_ij', bias 'bc_i' and neuron 'gc' from the LSTM formula.
767        /// If nil then assumed zero weights, bias and no neuron (identity mapping). Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
768        #[unsafe(method(cellGateInputWeights))]
769        #[unsafe(method_family = none)]
770        pub unsafe fn cellGateInputWeights(
771            &self,
772        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
773
774        #[cfg(feature = "MPSCNNConvolution")]
775        /// Setter for [`cellGateInputWeights`][Self::cellGateInputWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights, zero bias and no neuron (identity mapping). Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
776        #[unsafe(method(setCellGateInputWeights:))]
777        #[unsafe(method_family = none)]
778        pub unsafe fn setCellGateInputWeights(
779            &self,
780            cell_gate_input_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
781        );
782
783        #[cfg(feature = "MPSCNNConvolution")]
784        /// Contains weights 'Uc_ij' from the LSTM formula.
785        /// If nil then assumed zero weights. Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
786        #[unsafe(method(cellGateRecurrentWeights))]
787        #[unsafe(method_family = none)]
788        pub unsafe fn cellGateRecurrentWeights(
789            &self,
790        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
791
792        #[cfg(feature = "MPSCNNConvolution")]
793        /// Setter for [`cellGateRecurrentWeights`][Self::cellGateRecurrentWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights are assumed. Only compiled when the `MPSCNNConvolution` Cargo feature
           /// is enabled.
794        #[unsafe(method(setCellGateRecurrentWeights:))]
795        #[unsafe(method_family = none)]
796        pub unsafe fn setCellGateRecurrentWeights(
797            &self,
798            cell_gate_recurrent_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
799        );
800
801        #[cfg(feature = "MPSCNNConvolution")]
802        /// Contains weights 'Vc_ij' - the 'peephole' weights - from the LSTM formula.
803        /// If `memoryWeightsAreDiagonal` is YES, then the number of weights used is the number of features
804        /// in the memory cell image/matrix.
805        /// If nil then assumed zero weights. Defaults to nil.
           ///
           /// A nil property is returned as `None`. Only compiled when the
           /// `MPSCNNConvolution` Cargo feature is enabled.
806        #[unsafe(method(cellGateMemoryWeights))]
807        #[unsafe(method_family = none)]
808        pub unsafe fn cellGateMemoryWeights(
809            &self,
810        ) -> Option<Retained<ProtocolObject<dyn MPSCNNConvolutionDataSource>>>;
811
812        #[cfg(feature = "MPSCNNConvolution")]
813        /// Setter for [`cellGateMemoryWeights`][Self::cellGateMemoryWeights].
           ///
           /// Pass `None` to set the property to nil; per the getter's documentation, nil means
           /// zero weights are assumed. Only compiled when the `MPSCNNConvolution` Cargo feature
           /// is enabled.
814        #[unsafe(method(setCellGateMemoryWeights:))]
815        #[unsafe(method_family = none)]
816        pub unsafe fn setCellGateMemoryWeights(
817            &self,
818            cell_gate_memory_weights: Option<&ProtocolObject<dyn MPSCNNConvolutionDataSource>>,
819        );
820
821        #[cfg(feature = "MPSCNNNeuronType")]
822        /// Neuron type definition for 'gh', see
823        /// `MPSCNNNeuronType`. Defaults to `MPSCNNNeuronTypeTanH`.
           ///
           /// Only compiled when the `MPSCNNNeuronType` Cargo feature is enabled.
824        #[unsafe(method(cellToOutputNeuronType))]
825        #[unsafe(method_family = none)]
826        pub unsafe fn cellToOutputNeuronType(&self) -> MPSCNNNeuronType;
827
828        #[cfg(feature = "MPSCNNNeuronType")]
829        /// Setter for [`cellToOutputNeuronType`][Self::cellToOutputNeuronType].
           ///
           /// Takes the new neuron type for 'gh' by value. Only compiled when the
           /// `MPSCNNNeuronType` Cargo feature is enabled.
830        #[unsafe(method(setCellToOutputNeuronType:))]
831        #[unsafe(method_family = none)]
832        pub unsafe fn setCellToOutputNeuronType(
833            &self,
834            cell_to_output_neuron_type: MPSCNNNeuronType,
835        );
836
837        /// Neuron parameter A for 'gh'. Defaults to 1.0f.
           ///
           /// 'gh' is the neuron whose type is described by `cellToOutputNeuronType`.
838        #[unsafe(method(cellToOutputNeuronParamA))]
839        #[unsafe(method_family = none)]
840        pub unsafe fn cellToOutputNeuronParamA(&self) -> c_float;
841
842        /// Setter for [`cellToOutputNeuronParamA`][Self::cellToOutputNeuronParamA].
           ///
           /// Takes the new value of neuron parameter A for 'gh' as a `c_float`.
843        #[unsafe(method(setCellToOutputNeuronParamA:))]
844        #[unsafe(method_family = none)]
845        pub unsafe fn setCellToOutputNeuronParamA(&self, cell_to_output_neuron_param_a: c_float);
846
847        /// Neuron parameter B for 'gh'. Defaults to 1.0f.
           ///
           /// 'gh' is the neuron whose type is described by `cellToOutputNeuronType`.
848        #[unsafe(method(cellToOutputNeuronParamB))]
849        #[unsafe(method_family = none)]
850        pub unsafe fn cellToOutputNeuronParamB(&self) -> c_float;
851
852        /// Setter for [`cellToOutputNeuronParamB`][Self::cellToOutputNeuronParamB].
           ///
           /// Takes the new value of neuron parameter B for 'gh' as a `c_float`.
853        #[unsafe(method(setCellToOutputNeuronParamB:))]
854        #[unsafe(method_family = none)]
855        pub unsafe fn setCellToOutputNeuronParamB(&self, cell_to_output_neuron_param_b: c_float);
856
857        /// Neuron parameter C for 'gh'. Defaults to 1.0f.
           ///
           /// 'gh' is the neuron whose type is described by `cellToOutputNeuronType`.
858        #[unsafe(method(cellToOutputNeuronParamC))]
859        #[unsafe(method_family = none)]
860        pub unsafe fn cellToOutputNeuronParamC(&self) -> c_float;
861
862        /// Setter for [`cellToOutputNeuronParamC`][Self::cellToOutputNeuronParamC].
           ///
           /// Takes the new value of neuron parameter C for 'gh' as a `c_float`.
863        #[unsafe(method(setCellToOutputNeuronParamC:))]
864        #[unsafe(method_family = none)]
865        pub unsafe fn setCellToOutputNeuronParamC(&self, cell_to_output_neuron_param_c: c_float);
866
867        /// Creates a LSTM descriptor.
868        ///
869        /// Parameter `inputFeatureChannels`: The number of feature channels in the input image/matrix. Must be >= 1.
870        ///
871        /// Parameter `outputFeatureChannels`: The number of feature channels in the output image/matrix. Must be >= 1.
872        ///
873        /// Returns: A valid MPSLSTMDescriptor object or nil, if failure.
           ///
           /// Note: this binding is a class method (it takes no `self` receiver) and declares a
           /// non-optional `Retained<Self>` return type, so the "or nil" case mentioned above
           /// cannot be observed as `None` through this signature.
874        #[unsafe(method(createLSTMDescriptorWithInputFeatureChannels:outputFeatureChannels:))]
875        #[unsafe(method_family = none)]
876        pub unsafe fn createLSTMDescriptorWithInputFeatureChannels_outputFeatureChannels(
877            input_feature_channels: NSUInteger,
878            output_feature_channels: NSUInteger,
879        ) -> Retained<Self>;
880    );
881}
882
883/// Methods declared on superclass `NSObject`.
884impl MPSLSTMDescriptor {
885    extern_methods!(
           /// Binding for the Objective-C `init` selector (init method family): takes an
           /// `Allocated<Self>` and returns the initialized object as `Retained<Self>`.
886        #[unsafe(method(init))]
887        #[unsafe(method_family = init)]
888        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
889
           /// Binding for the Objective-C `new` class method (new method family); returns the
           /// new object as `Retained<Self>`.
890        #[unsafe(method(new))]
891        #[unsafe(method_family = new)]
892        pub unsafe fn new() -> Retained<Self>;
893    );
894}
895
896extern_class!(
897    /// Dependencies: This depends on Metal.framework
898    ///
899    /// This class holds all the data that is passed from one sequence iteration of the image-based RNN layer (stack) to the next.
900    ///
       /// Declared with the superclass chain `MPSState`, `NSObject`. Only compiled when the
       /// `MPSCore` and `MPSState` Cargo features are both enabled.
       ///
901    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnrecurrentimagestate?language=objc)
902    #[unsafe(super(MPSState, NSObject))]
903    #[derive(Debug, PartialEq, Eq, Hash)]
904    #[cfg(all(feature = "MPSCore", feature = "MPSState"))]
905    pub struct MPSRNNRecurrentImageState;
906);
907
908#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
   // Declares that `MPSRNNRecurrentImageState` conforms to `NSObjectProtocol`. Gated on the
   // same `MPSCore` + `MPSState` features as the class declaration.
909extern_conformance!(
910    unsafe impl NSObjectProtocol for MPSRNNRecurrentImageState {}
911);
912
913#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
914impl MPSRNNRecurrentImageState {
915    extern_methods!(
916        #[cfg(feature = "MPSImage")]
917        /// Access the stored recurrent image data.
918        ///
919        /// Parameter `layerIndex`: Index of the layer whose image to get - belongs to
920        /// { 0, 1, ..., numberOfLayers - 1 }.
921        /// See: `numberOfLayers`.
922        ///
923        /// Returns: For valid layerIndex the recurrent output image data, otherwise nil.
           ///
           /// A nil result is returned as `None`. Only compiled when the `MPSImage` Cargo
           /// feature is enabled.
924        #[unsafe(method(getRecurrentOutputImageForLayerIndex:))]
925        #[unsafe(method_family = none)]
926        pub unsafe fn getRecurrentOutputImageForLayerIndex(
927            &self,
928            layer_index: NSUInteger,
929        ) -> Option<Retained<MPSImage>>;
930
931        #[cfg(feature = "MPSImage")]
932        /// Access the stored memory cell image data (if present).
933        ///
934        /// Parameter `layerIndex`: Index of the layer whose image to get - belongs to
935        /// { 0, 1, ..., numberOfLayers - 1 }.
936        /// See: `numberOfLayers`.
937        ///
938        /// Returns: For valid layerIndex the memory cell image data, otherwise nil.
           ///
           /// A nil result is returned as `None`. Only compiled when the `MPSImage` Cargo
           /// feature is enabled.
939        #[unsafe(method(getMemoryCellImageForLayerIndex:))]
940        #[unsafe(method_family = none)]
941        pub unsafe fn getMemoryCellImageForLayerIndex(
942            &self,
943            layer_index: NSUInteger,
944        ) -> Option<Retained<MPSImage>>;
945    );
946}
947
948/// Methods declared on superclass `MPSState`.
   ///
   /// The descriptions below speak of `MPSState`; in this `impl` every constructor returns
   /// `Self`, i.e. `MPSRNNRecurrentImageState`.
949#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
950impl MPSRNNRecurrentImageState {
951    extern_methods!(
952        /// Create a MPSState holding a temporary MTLBuffer
953        ///
954        /// Parameter `cmdBuf`: The command buffer against which the temporary resource is allocated
955        ///
956        /// Parameter `bufferSize`: The size of the buffer in bytes
957        #[unsafe(method(temporaryStateWithCommandBuffer:bufferSize:))]
958        #[unsafe(method_family = none)]
959        pub unsafe fn temporaryStateWithCommandBuffer_bufferSize(
960            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
961            buffer_size: usize,
962        ) -> Retained<Self>;
963
964        /// Create a MPSState holding a temporary MTLTexture
965        ///
966        /// Parameter `cmdBuf`: The command buffer against which the temporary resource is allocated
967        ///
968        /// Parameter `descriptor`: A descriptor for the new temporary texture
969        #[unsafe(method(temporaryStateWithCommandBuffer:textureDescriptor:))]
970        #[unsafe(method_family = none)]
971        pub unsafe fn temporaryStateWithCommandBuffer_textureDescriptor(
972            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
973            descriptor: &MTLTextureDescriptor,
974        ) -> Retained<Self>;
975
976        /// Create a new autoreleased temporary state object without underlying resource
977        ///
978        /// Parameter `cmdBuf`: The command buffer with which the temporary resource is associated
979        #[unsafe(method(temporaryStateWithCommandBuffer:))]
980        #[unsafe(method_family = none)]
981        pub unsafe fn temporaryStateWithCommandBuffer(
982            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
983        ) -> Retained<Self>;
984
           /// Binding for the `initWithDevice:bufferSize:` initializer; no description is
           /// carried over for it in this file.
           ///
           /// NOTE(review): presumably the non-temporary counterpart of
           /// `temporaryStateWithCommandBuffer:bufferSize:` - confirm against Apple's MPSState
           /// documentation.
985        #[unsafe(method(initWithDevice:bufferSize:))]
986        #[unsafe(method_family = init)]
987        pub unsafe fn initWithDevice_bufferSize(
988            this: Allocated<Self>,
989            device: &ProtocolObject<dyn MTLDevice>,
990            buffer_size: usize,
991        ) -> Retained<Self>;
992
           /// Binding for the `initWithDevice:textureDescriptor:` initializer; no description is
           /// carried over for it in this file.
           ///
           /// NOTE(review): presumably the non-temporary counterpart of
           /// `temporaryStateWithCommandBuffer:textureDescriptor:` - confirm against Apple's
           /// MPSState documentation.
993        #[unsafe(method(initWithDevice:textureDescriptor:))]
994        #[unsafe(method_family = init)]
995        pub unsafe fn initWithDevice_textureDescriptor(
996            this: Allocated<Self>,
997            device: &ProtocolObject<dyn MTLDevice>,
998            descriptor: &MTLTextureDescriptor,
999        ) -> Retained<Self>;
1000
1001        /// Create a MPSState with a non-temporary MTLResource
1002        ///
1003        /// Parameter `resource`: A MTLBuffer or MTLTexture. May be nil.
1004        ///
1005        /// # Safety
1006        ///
1007        /// - `resource` may need to be synchronized.
1008        /// - `resource` may be unretained, you must ensure it is kept alive while in use.
1009        #[unsafe(method(initWithResource:))]
1010        #[unsafe(method_family = init)]
1011        pub unsafe fn initWithResource(
1012            this: Allocated<Self>,
1013            resource: Option<&ProtocolObject<dyn MTLResource>>,
1014        ) -> Retained<Self>;
1015
            /// Binding for the plain `init` initializer. Unlike the other initializers in this
            /// block, its declared return type is `Option<Retained<Self>>`, so a nil result is
            /// returned as `None`.
1016        #[unsafe(method(init))]
1017        #[unsafe(method_family = init)]
1018        pub unsafe fn init(this: Allocated<Self>) -> Option<Retained<Self>>;
1019
1020        /// Initialize a non-temporary state to hold a number of textures and buffers
1021        ///
1022        /// The allocation of each resource will be deferred until it is needed.
1023        /// This occurs when -resource or -resourceAtIndex: is called.
1024        ///
1025        /// Parameter `resourceList`: The list of resources to create.
1026        #[unsafe(method(initWithDevice:resourceList:))]
1027        #[unsafe(method_family = init)]
1028        pub unsafe fn initWithDevice_resourceList(
1029            this: Allocated<Self>,
1030            device: &ProtocolObject<dyn MTLDevice>,
1031            resource_list: &MPSStateResourceList,
1032        ) -> Retained<Self>;
1033
1034        /// Initialize a temporary state to hold a number of textures and buffers
1035        ///
1036        /// The textures occur first in sequence
1037        #[unsafe(method(temporaryStateWithCommandBuffer:resourceList:))]
1038        #[unsafe(method_family = none)]
1039        pub unsafe fn temporaryStateWithCommandBuffer_resourceList(
1040            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1041            resource_list: &MPSStateResourceList,
1042        ) -> Retained<Self>;
1043
1044        /// Create a state object with a list of MTLResources
1045        ///
1046        /// Because MPS prefers deferred allocation of resources
1047        /// your application should use -initWithTextures:bufferSizes:bufferCount:
1048        /// whenever possible. This method is useful for cases when the
1049        /// MTLResources must be initialized by the CPU.
1050        ///
1051        /// # Safety
1052        ///
1053        /// - `resources` generic may need to be synchronized.
1054        /// - `resources` generic may be unretained, you must ensure it is kept alive while in use.
1055        #[unsafe(method(initWithResources:))]
1056        #[unsafe(method_family = init)]
1057        pub unsafe fn initWithResources(
1058            this: Allocated<Self>,
1059            resources: Option<&NSArray<ProtocolObject<dyn MTLResource>>>,
1060        ) -> Retained<Self>;
1061    );
1062}
1063
1064/// Methods declared on superclass `NSObject`.
1065#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
1066impl MPSRNNRecurrentImageState {
1067    extern_methods!(
            /// Binding for the Objective-C `new` class method (new method family); returns the
            /// new object as `Retained<Self>`.
1068        #[unsafe(method(new))]
1069        #[unsafe(method_family = new)]
1070        pub unsafe fn new() -> Retained<Self>;
1071    );
1072}
1073
1074extern_class!(
1075    /// Dependencies: This depends on Metal.framework
1076    ///
1077    /// The MPSRNNImageInferenceLayer specifies a recurrent neural network layer for inference on MPSImages.
1078    /// Currently two types of recurrent layers are supported: ones that operate with convolutions on
1079    /// images:
1080    /// `MPSRNNImageInferenceLayer` and one that operates on matrices:
1081    /// `MPSRNNMatrixInferenceLayer`. The former can often be used to implement the latter by using 1x1-images, but due to
1082    /// image size restrictions and performance, it is advisable to use
1083    /// `MPSRNNMatrixInferenceLayer` for
1084    /// linear recurrent layers.
1085    /// A MPSRNNImageInferenceLayer is initialized using a
1086    /// `MPSRNNLayerDescriptor`, which further specifies the
1087    /// recurrent network layer, or an array of
1088    /// `MPSRNNLayerDescriptors`, which specifies a stack
1089    /// of recurrent layers, that can operate in parallel on a subset of the inputs in a sequence of inputs and
1090    /// recurrent outputs. Note that currently stacks with bidirectionally traversing encode functions do not support starting
1091    /// from a previous set of recurrent states, but this can be achieved quite easily by defining two separate
1092    /// unidirectional stacks of layers, and running the same input sequence on them separately (one forwards and one backwards)
1093    /// and ultimately combining the two result sequences as desired with auxiliary functions.
1094    ///
        /// Declared with the superclass chain `MPSCNNKernel`, `MPSKernel`, `NSObject`. Only compiled
        /// when the `MPSCNNKernel`, `MPSCore` and `MPSKernel` Cargo features are all enabled.
        ///
1095    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnimageinferencelayer?language=objc)
1096    #[unsafe(super(MPSCNNKernel, MPSKernel, NSObject))]
1097    #[derive(Debug, PartialEq, Eq, Hash)]
1098    #[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
1099    pub struct MPSRNNImageInferenceLayer;
1100);
1101
1102#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    // Declares that `MPSRNNImageInferenceLayer` conforms to `NSCoding`. Like every item in this
    // group, it is gated on the same three features as the class declaration.
1103extern_conformance!(
1104    unsafe impl NSCoding for MPSRNNImageInferenceLayer {}
1105);
1106
1107#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    // Declares that `MPSRNNImageInferenceLayer` conforms to `NSCopying`.
1108extern_conformance!(
1109    unsafe impl NSCopying for MPSRNNImageInferenceLayer {}
1110);
1111
1112#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    // Names `Self` as the `CopyingHelper::Result` type for this class.
1113unsafe impl CopyingHelper for MPSRNNImageInferenceLayer {
1114    type Result = Self;
1115}
1116
1117#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    // Declares that `MPSRNNImageInferenceLayer` conforms to `NSObjectProtocol`.
1118extern_conformance!(
1119    unsafe impl NSObjectProtocol for MPSRNNImageInferenceLayer {}
1120);
1121
1122#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    // Declares that `MPSRNNImageInferenceLayer` conforms to `NSSecureCoding`.
1123extern_conformance!(
1124    unsafe impl NSSecureCoding for MPSRNNImageInferenceLayer {}
1125);
1126
1127#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
1128impl MPSRNNImageInferenceLayer {
1129    extern_methods!(
1130        /// The number of feature channels per pixel in the input image.
            ///
            /// Bound to the Objective-C `inputFeatureChannels` selector; the value is returned as
            /// an `NSUInteger`.
1131        #[unsafe(method(inputFeatureChannels))]
1132        #[unsafe(method_family = none)]
1133        pub unsafe fn inputFeatureChannels(&self) -> NSUInteger;
1134
1135        /// The number of feature channels per pixel in the output image.
            ///
            /// Bound to the Objective-C `outputFeatureChannels` selector; the value is returned as
            /// an `NSUInteger`.
1136        #[unsafe(method(outputFeatureChannels))]
1137        #[unsafe(method_family = none)]
1138        pub unsafe fn outputFeatureChannels(&self) -> NSUInteger;
1139
1140        /// Number of layers in the filter-stack. This will be one when using `initWithDevice:rnnDescriptor:` to initialize
1141        /// this filter and the number of entries in the array 'rnnDescriptors' when initializing this filter with
1142        /// `initWithDevice:rnnDescriptors:`.
1143        #[unsafe(method(numberOfLayers))]
1144        #[unsafe(method_family = none)]
1145        pub unsafe fn numberOfLayers(&self) -> NSUInteger;
1146
1147        /// How output states from
1148        /// `encodeSequenceToCommandBuffer` are constructed.
1149        /// Defaults to NO.
1150        ///
1151        /// For reference, see: `MPSState`.
            ///
            /// Per the documentation of the sequence-encoding method in this file, when this is
            /// YES all returned recurrent states will be temporary.
1152        #[unsafe(method(recurrentOutputIsTemporary))]
1153        #[unsafe(method_family = none)]
1154        pub unsafe fn recurrentOutputIsTemporary(&self) -> bool;
1155
1156        /// Setter for [`recurrentOutputIsTemporary`][Self::recurrentOutputIsTemporary].
            ///
            /// Takes the new flag value as a Rust `bool`.
1157        #[unsafe(method(setRecurrentOutputIsTemporary:))]
1158        #[unsafe(method_family = none)]
1159        pub unsafe fn setRecurrentOutputIsTemporary(&self, recurrent_output_is_temporary: bool);
1160
1161        /// If YES then calls to
1162        /// `encodeSequenceToCommandBuffer` return every recurrent state
1163        /// in the array: recurrentOutputStates.
1164        /// Defaults to NO.
            ///
            /// Per the documentation of the sequence-encoding method in this file, when this is
            /// NO only the last recurrent output state is returned.
1165        #[unsafe(method(storeAllIntermediateStates))]
1166        #[unsafe(method_family = none)]
1167        pub unsafe fn storeAllIntermediateStates(&self) -> bool;
1168
1169        /// Setter for [`storeAllIntermediateStates`][Self::storeAllIntermediateStates].
            ///
            /// Takes the new flag value as a Rust `bool`.
1170        #[unsafe(method(setStoreAllIntermediateStates:))]
1171        #[unsafe(method_family = none)]
1172        pub unsafe fn setStoreAllIntermediateStates(&self, store_all_intermediate_states: bool);
1173
1174        /// Defines how to combine the output-results, when encoding bidirectional layers using
1175        /// `encodeBidirectionalSequenceToCommandBuffer`. Defaults to
1176        /// `MPSRNNBidirectionalCombineModeNone`.
            ///
            /// In this crate that default is spelled [`MPSRNNBidirectionalCombineMode::None`].
1177        #[unsafe(method(bidirectionalCombineMode))]
1178        #[unsafe(method_family = none)]
1179        pub unsafe fn bidirectionalCombineMode(&self) -> MPSRNNBidirectionalCombineMode;
1180
1181        /// Setter for [`bidirectionalCombineMode`][Self::bidirectionalCombineMode].
            ///
            /// Takes the new combine mode by value as an [`MPSRNNBidirectionalCombineMode`].
1182        #[unsafe(method(setBidirectionalCombineMode:))]
1183        #[unsafe(method_family = none)]
1184        pub unsafe fn setBidirectionalCombineMode(
1185            &self,
1186            bidirectional_combine_mode: MPSRNNBidirectionalCombineMode,
1187        );
1188
1189        /// Initializes a convolutional RNN kernel
1190        ///
1191        /// Parameter `device`: The MTLDevice on which this MPSRNNImageLayer filter will be used
1192        ///
1193        /// Parameter `rnnDescriptor`: The descriptor that defines the RNN layer
1194        ///
1195        /// Returns: A valid MPSRNNImageInferenceLayer object or nil, if failure.
            ///
            /// Note: this binding declares a non-optional `Retained<Self>` return type, so the
            /// "or nil" case mentioned above cannot be observed as `None` through this signature.
1196        #[unsafe(method(initWithDevice:rnnDescriptor:))]
1197        #[unsafe(method_family = init)]
1198        pub unsafe fn initWithDevice_rnnDescriptor(
1199            this: Allocated<Self>,
1200            device: &ProtocolObject<dyn MTLDevice>,
1201            rnn_descriptor: &MPSRNNDescriptor,
1202        ) -> Retained<Self>;
1203
1204        /// Initializes a kernel that implements a stack of convolutional RNN layers
1205        ///
1206        /// Parameter `device`: The MTLDevice on which this MPSRNNImageLayer filter will be used
1207        ///
1208        /// Parameter `rnnDescriptors`: An array of RNN descriptors that defines a stack of RNN layers, starting at index zero.
1209        /// The number of layers in stack is the number of entries in the array.
1210        /// All entries in the array must be valid MPSRNNDescriptors.
1211        ///
1212        /// Returns: A valid MPSRNNImageInferenceLayer object or nil, if failure.
            ///
            /// Note: this binding declares a non-optional `Retained<Self>` return type, so the
            /// "or nil" case mentioned above cannot be observed as `None` through this signature.
1213        #[unsafe(method(initWithDevice:rnnDescriptors:))]
1214        #[unsafe(method_family = init)]
1215        pub unsafe fn initWithDevice_rnnDescriptors(
1216            this: Allocated<Self>,
1217            device: &ProtocolObject<dyn MTLDevice>,
1218            rnn_descriptors: &NSArray<MPSRNNDescriptor>,
1219        ) -> Retained<Self>;
1220
1221        #[unsafe(method(initWithDevice:))]
            /// Binding for the `initWithDevice:` initializer; no description is carried over for
            /// it in this file.
            ///
            /// NOTE(review): this variant takes no RNN descriptor. Whether the resulting layer is
            /// usable is not stated here - confirm against Apple's documentation before relying
            /// on it.
1222        #[unsafe(method_family = init)]
1223        pub unsafe fn initWithDevice(
1224            this: Allocated<Self>,
1225            device: &ProtocolObject<dyn MTLDevice>,
1226        ) -> Retained<Self>;
1227
1228        #[cfg(all(feature = "MPSImage", feature = "MPSState"))]
1229        /// Encode an MPSRNNImageInferenceLayer kernel (stack) for a sequence of inputs into a command buffer.
1230        /// Note that when encoding using this function the
1231        /// `layerSequenceDirection`
1232        /// is ignored and the layer stack operates as
1233        /// if all layers were forward feeding layers. In order to run bidirectional sequences
1234        /// use
1235        /// `encodeBidirectionalSequenceToCommandBuffer:sourceSequence:` or alternatively run two layer stacks and combine
1236        /// results at the end using utility functions.
1237        ///
1238        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
1239        ///
1240        /// Parameter `sourceImages`: An array of valid MPSImage objects containing the sequence of source images.
1241        ///
1242        /// Parameter `destinationImages`: An array of valid MPSImages to be overwritten by result image sequence. destinationImages may not alias sourceImages.
1243        ///
1244        /// Parameter `recurrentInputState`: An optional state containing the output images and memory cells (for LSTMs)
1245        /// of the layer obtained from the previous input images in a sequence of inputs.
1246        /// Has to be the output of a previous call to this function or nil (assumed zero).
1247        /// Note: can be one of the states returned in
1248        /// `recurrentOutputStates`.
            ///
1249        /// Parameter `recurrentOutputStates`: An optional array that will contain the recurrent output states. If nil then
1250        /// the recurrent output state is discarded.
1251        /// If
1252        /// `storeAllIntermediateStates` is YES, then all intermediate states of the sequence
1253        /// are returned in the array, the first one corresponding to the first input in the sequence,
1254        /// otherwise only the last recurrent output state is returned.
1255        /// If `recurrentOutputIsTemporary` is YES, then all returned recurrent states
1256        /// will be temporary.
1257        ///
1258        /// See: MPSState:isTemporary.
1259        /// Example: In order to get a new state one can do the following:
1260        ///
1261        /// ```text
1262        ///                                                       MPSRNNRecurrentImageState* recurrent0 = nil;
1263        ///                                                       [filter encodeToCommandBuffer: cmdBuf
1264        ///                                                                         sourceImage: source0
1265        ///                                                                    destinationImage: destination0
1266        ///                                                                 recurrentInputState: nil
1267        ///                                                                recurrentOutputState: &recurrent0];
1268        /// ```
1269        ///
1270        /// Then use it for the next input in sequence:
1271        ///
1272        /// ```text
1273        ///                                                       [filter encodeToCommandBuffer: cmdBuf
1274        ///                                                                         sourceImage: source1
1275        ///                                                                    destinationImage: destination1
1276        ///                                                                 recurrentInputState: recurrent0
1277        ///                                                                recurrentOutputState: &recurrent0];
1278        /// ```
1279        ///
1280        /// And discard recurrent output of the third input:
1281        ///
1282        /// ```text
1283        ///                                                       [filter encodeToCommandBuffer: cmdBuf
1284        ///                                                                         sourceImage: source2
1285        ///                                                                    destinationImage: destination2
1286        ///                                                                 recurrentInputState: recurrent0
1287        ///                                                                recurrentOutputState: nil];
1288        /// ```
            ///
            /// NOTE(review): the examples above are reproduced from the original documentation and
            /// use a single-image `encodeToCommandBuffer:sourceImage:...` call, not the selector
            /// bound by this method.
            ///
            /// In this binding the two optional arguments are `Option`s; pass `None` for nil. Only
            /// compiled when the `MPSImage` and `MPSState` Cargo features are both enabled.
1289        #[unsafe(method(encodeSequenceToCommandBuffer:sourceImages:destinationImages:recurrentInputState:recurrentOutputStates:))]
1290        #[unsafe(method_family = none)]
1291        pub unsafe fn encodeSequenceToCommandBuffer_sourceImages_destinationImages_recurrentInputState_recurrentOutputStates(
1292            &self,
1293            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1294            source_images: &NSArray<MPSImage>,
1295            destination_images: &NSArray<MPSImage>,
1296            recurrent_input_state: Option<&MPSRNNRecurrentImageState>,
1297            recurrent_output_states: Option<&NSMutableArray<MPSRNNRecurrentImageState>>,
1298        );
1299
1300        #[cfg(feature = "MPSImage")]
1301        /// Encode an MPSRNNImageInferenceLayer kernel stack for an input image sequence into a command buffer bidirectionally.
1302        /// The operation proceeds as follows: The first source image x0 is passed through all forward traversing layers in the stack,
1303        /// i.e. those that were initialized with MPSRNNSequenceDirectionForward, recurrent input is assumed zero.
1304        /// This produces forward output yf0 and recurrent states hf00, hf01, hf02, ... hf0n, one for each forward layer.
1305        /// Then x1 is passed to forward layers together with recurrent state hf00, hf01, ..., hf0n, which produces yf1, and hf10,...
1306        /// This procedure is iterated until the last image in the input sequence x_(N-1), which produces forward output yf(N-1).
1307        /// The backwards layers iterate the same sequence backwards, starting from input x_(N-1) (recurrent state zero),
1308        /// that produces yb(N-1) and recurrent output hb(N-1)0, hb(N-1)1, ... hb(N-1)m, one for each backwards traversing layer.
1309        /// Then the backwards layers handle input x_(N-2) using recurrent state hb(N-1)0, ..., et cetera, until the
1310        /// first image of the sequence is computed, producing output yb0. The result of the operation is either a pair of sequences
1311        /// ({yf0, yf1, ... , yf(N-1)},  {yb0, yb1, ... , yb(N-1)}) or a combined sequence, {(yf0 + yb0), ... , (yf(N-1) + yb(N-1)) },
1312        /// where '+' stands either for sum, or concatenation along feature channels, as specified by
1313        /// bidirectionalCombineMode.
1314        ///
1315        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
1316        ///
1317        /// Parameter `sourceSequence`: An array of valid MPSImage objects containing the source image sequence (x0, x1, ... x_n-1).
1318        ///
1319        /// Parameter `destinationForwardImages`: An array of valid MPSImages to be overwritten by result from forward input images. If bidirectionalCombineMode
1320        /// is either MPSRNNBidirectionalCombineModeAdd or MPSRNNBidirectionalCombineModeConcatenate, then this will
1321        /// contain the combined results. destinationForwardImages may not alias with any of the source images.
1322        ///
1323        /// Parameter `destinationBackwardImages`: If bidirectionalCombineMode is MPSRNNBidirectionalCombineModeNone, then this must be an array of valid MPSImages
1324        /// that will be overwritten by the results from the backward input images. Otherwise this parameter is ignored
1325        /// and can be nil. destinationBackwardImages may not alias to any of the source images.
1326        #[unsafe(method(encodeBidirectionalSequenceToCommandBuffer:sourceSequence:destinationForwardImages:destinationBackwardImages:))]
1327        #[unsafe(method_family = none)]
1328        pub unsafe fn encodeBidirectionalSequenceToCommandBuffer_sourceSequence_destinationForwardImages_destinationBackwardImages(
1329            &self,
1330            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1331            source_sequence: &NSArray<MPSImage>,
1332            destination_forward_images: &NSArray<MPSImage>,
1333            destination_backward_images: Option<&NSArray<MPSImage>>,
1334        );
1335
1336        /// NSSecureCoding compatibility
1337        ///
1338        /// See: MPSKernel#initWithCoder.
1339        ///
1340        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSRNNImageInferenceLayer
1341        ///
1342        /// Parameter `device`: The MTLDevice on which to make the MPSRNNImageInferenceLayer
1343        ///
1344        /// Returns: A new MPSRNNImageInferenceLayer object, or nil if failure.
1345        ///
1346        /// # Safety
1347        ///
1348        /// `a_decoder` possibly has further requirements.
1349        #[unsafe(method(initWithCoder:device:))]
1350        #[unsafe(method_family = init)]
1351        pub unsafe fn initWithCoder_device(
1352            this: Allocated<Self>,
1353            a_decoder: &NSCoder,
1354            device: &ProtocolObject<dyn MTLDevice>,
1355        ) -> Option<Retained<Self>>;
1356
1357        /// Make a copy of this kernel for a new device.
1358        ///
1359        /// See: MPSKernel
1360        ///
1361        /// Parameter `zone`: The NSZone in which to allocate the object
1362        ///
1363        /// Parameter `device`: The device for the new MPSKernel. If nil, then use
1364        /// self.device.
1365        ///
1366        /// Returns: A pointer to a copy of this MPSKernel. This will fail, returning
1367        /// nil, if the device is not supported. Devices must be
1368        /// MTLFeatureSet_iOS_GPUFamily2_v1 or later.
1369        ///
1370        /// # Safety
1371        ///
1372        /// `zone` must be a valid pointer or null.
1373        #[unsafe(method(copyWithZone:device:))]
1374        #[unsafe(method_family = copy)]
1375        pub unsafe fn copyWithZone_device(
1376            &self,
1377            zone: *mut NSZone,
1378            device: Option<&ProtocolObject<dyn MTLDevice>>,
1379        ) -> Retained<Self>;
1380    );
1381}
1382
1383/// Methods declared on superclass `MPSKernel`.
1384#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
1385impl MPSRNNImageInferenceLayer {
1386    extern_methods!(
1387        /// Called by NSCoder to decode MPSKernels.
1388        ///
1389        /// This isn't the right interface to decode an MPSKernel, but
1390        /// it is the one that NSCoder uses. To enable your NSCoder
1391        /// (e.g. NSKeyedUnarchiver) to set which device to use,
1392        /// extend the object to adopt the MPSDeviceProvider
1393        /// protocol. Otherwise, the Metal system default device
1394        /// will be used.
1395        ///
1396        /// # Safety
1397        ///
1398        /// `a_decoder` possibly has further requirements.
1399        #[unsafe(method(initWithCoder:))]
1400        #[unsafe(method_family = init)]
1401        pub unsafe fn initWithCoder(
1402            this: Allocated<Self>,
1403            a_decoder: &NSCoder,
1404        ) -> Option<Retained<Self>>;
1405    );
1406}
1407
1408/// Methods declared on superclass `NSObject`: bindings for the `init` and `new` selectors.
1409#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
1410impl MPSRNNImageInferenceLayer {
1411    extern_methods!(
1412        #[unsafe(method(init))]
1413        #[unsafe(method_family = init)]
1414        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
1415
1416        #[unsafe(method(new))]
1417        #[unsafe(method_family = new)]
1418        pub unsafe fn new() -> Retained<Self>;
1419    );
1420}
1421
1422extern_class!(
1423    /// Dependencies: This depends on Metal.framework.
1424    ///
1425    /// This class holds all the data that is passed from one sequence iteration of the matrix-based RNN layer to the next. It is declared here as a subclass of `MPSState`.
1426    ///
1427    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnrecurrentmatrixstate?language=objc)
1428    #[unsafe(super(MPSState, NSObject))]
1429    #[derive(Debug, PartialEq, Eq, Hash)]
1430    #[cfg(all(feature = "MPSCore", feature = "MPSState"))]
1431    pub struct MPSRNNRecurrentMatrixState;
1432);
1433
1434#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
1435extern_conformance!(
1436    unsafe impl NSObjectProtocol for MPSRNNRecurrentMatrixState {}
1437);
1438
1439#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
1440impl MPSRNNRecurrentMatrixState {
1441    extern_methods!(
1442        #[cfg(feature = "MPSMatrix")]
1443        /// Access the stored recurrent matrix data.
1444        ///
1445        /// Parameter `layerIndex`: Index of the layer whose matrix to get - belongs to { 0, 1, ..., numberOfLayers - 1 }.
1446        ///
1447        /// See: numberOfLayers
1448        ///
1449        /// Returns: For valid layerIndex the recurrent output matrix data, otherwise nil.
1450        #[unsafe(method(getRecurrentOutputMatrixForLayerIndex:))]
1451        #[unsafe(method_family = none)]
1452        pub unsafe fn getRecurrentOutputMatrixForLayerIndex(
1453            &self,
1454            layer_index: NSUInteger,
1455        ) -> Option<Retained<MPSMatrix>>;
1456
1457        #[cfg(feature = "MPSMatrix")]
1458        /// Access the stored memory cell matrix data (if present).
1459        ///
1460        /// Parameter `layerIndex`: Index of the layer whose matrix to get - belongs to { 0, 1, ..., numberOfLayers - 1 }.
1461        ///
1462        /// See: numberOfLayers
1463        ///
1464        /// Returns: For valid layerIndex the memory cell matrix data, otherwise nil.
1465        #[unsafe(method(getMemoryCellMatrixForLayerIndex:))]
1466        #[unsafe(method_family = none)]
1467        pub unsafe fn getMemoryCellMatrixForLayerIndex(
1468            &self,
1469            layer_index: NSUInteger,
1470        ) -> Option<Retained<MPSMatrix>>;
1471    );
1472}
1473
1474/// Methods declared on superclass `MPSState`.
1475#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
1476impl MPSRNNRecurrentMatrixState {
1477    extern_methods!(
1478        /// Create an MPSState holding a temporary MTLBuffer
1479        ///
1480        /// Parameter `cmdBuf`: The command buffer against which the temporary resource is allocated
1481        ///
1482        /// Parameter `bufferSize`: The size of the buffer in bytes
1483        #[unsafe(method(temporaryStateWithCommandBuffer:bufferSize:))]
1484        #[unsafe(method_family = none)]
1485        pub unsafe fn temporaryStateWithCommandBuffer_bufferSize(
1486            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
1487            buffer_size: usize,
1488        ) -> Retained<Self>;
1489
1490        /// Create an MPSState holding a temporary MTLTexture
1491        ///
1492        /// Parameter `cmdBuf`: The command buffer against which the temporary resource is allocated
1493        ///
1494        /// Parameter `descriptor`: A descriptor for the new temporary texture
1495        #[unsafe(method(temporaryStateWithCommandBuffer:textureDescriptor:))]
1496        #[unsafe(method_family = none)]
1497        pub unsafe fn temporaryStateWithCommandBuffer_textureDescriptor(
1498            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
1499            descriptor: &MTLTextureDescriptor,
1500        ) -> Retained<Self>;
1501
1502        /// Create a new autoreleased temporary state object without an underlying resource
1503        ///
1504        /// Parameter `cmdBuf`: The command buffer with which the temporary resource is associated
1505        #[unsafe(method(temporaryStateWithCommandBuffer:))]
1506        #[unsafe(method_family = none)]
1507        pub unsafe fn temporaryStateWithCommandBuffer(
1508            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
1509        ) -> Retained<Self>;
1510
1511        #[unsafe(method(initWithDevice:bufferSize:))]
1512        #[unsafe(method_family = init)]
1513        pub unsafe fn initWithDevice_bufferSize(
1514            this: Allocated<Self>,
1515            device: &ProtocolObject<dyn MTLDevice>,
1516            buffer_size: usize,
1517        ) -> Retained<Self>;
1518
1519        #[unsafe(method(initWithDevice:textureDescriptor:))]
1520        #[unsafe(method_family = init)]
1521        pub unsafe fn initWithDevice_textureDescriptor(
1522            this: Allocated<Self>,
1523            device: &ProtocolObject<dyn MTLDevice>,
1524            descriptor: &MTLTextureDescriptor,
1525        ) -> Retained<Self>;
1526
1527        /// Create an MPSState with a non-temporary MTLResource
1528        ///
1529        /// Parameter `resource`: A MTLBuffer or MTLTexture. May be nil.
1530        ///
1531        /// # Safety
1532        ///
1533        /// - `resource` may need to be synchronized.
1534        /// - `resource` may be unretained, you must ensure it is kept alive while in use.
1535        #[unsafe(method(initWithResource:))]
1536        #[unsafe(method_family = init)]
1537        pub unsafe fn initWithResource(
1538            this: Allocated<Self>,
1539            resource: Option<&ProtocolObject<dyn MTLResource>>,
1540        ) -> Retained<Self>;
1541
1542        #[unsafe(method(init))]
1543        #[unsafe(method_family = init)]
1544        pub unsafe fn init(this: Allocated<Self>) -> Option<Retained<Self>>;
1545
1546        /// Initialize a non-temporary state to hold a number of textures and buffers
1547        ///
1548        /// The allocation of each resource will be deferred until it is needed.
1549        /// This occurs when -resource or -resourceAtIndex: is called.
1550        ///
1551        /// Parameter `resourceList`: The list of resources to create.
1552        #[unsafe(method(initWithDevice:resourceList:))]
1553        #[unsafe(method_family = init)]
1554        pub unsafe fn initWithDevice_resourceList(
1555            this: Allocated<Self>,
1556            device: &ProtocolObject<dyn MTLDevice>,
1557            resource_list: &MPSStateResourceList,
1558        ) -> Retained<Self>;
1559
1560        /// Initialize a temporary state to hold a number of textures and buffers
1561        ///
1562        /// The textures occur first in sequence.
1563        #[unsafe(method(temporaryStateWithCommandBuffer:resourceList:))]
1564        #[unsafe(method_family = none)]
1565        pub unsafe fn temporaryStateWithCommandBuffer_resourceList(
1566            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1567            resource_list: &MPSStateResourceList,
1568        ) -> Retained<Self>;
1569
1570        /// Create a state object with a list of MTLResources
1571        ///
1572        /// Because MPS prefers deferred allocation of resources,
1573        /// your application should use -initWithTextures:bufferSizes:bufferCount:
1574        /// whenever possible. This method is useful for cases when the
1575        /// MTLResources must be initialized by the CPU.
1576        ///
1577        /// # Safety
1578        ///
1579        /// - `resources` generic may need to be synchronized.
1580        /// - `resources` generic may be unretained, you must ensure it is kept alive while in use.
1581        #[unsafe(method(initWithResources:))]
1582        #[unsafe(method_family = init)]
1583        pub unsafe fn initWithResources(
1584            this: Allocated<Self>,
1585            resources: Option<&NSArray<ProtocolObject<dyn MTLResource>>>,
1586        ) -> Retained<Self>;
1587    );
1588}
1589
1590/// Methods declared on superclass `NSObject`: binding for the `new` selector.
1591#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
1592impl MPSRNNRecurrentMatrixState {
1593    extern_methods!(
1594        #[unsafe(method(new))]
1595        #[unsafe(method_family = new)]
1596        pub unsafe fn new() -> Retained<Self>;
1597    );
1598}
1599
1600extern_class!(
1601    /// Dependencies: This depends on Metal.framework
1602    ///
1603    /// The MPSRNNMatrixInferenceLayer specifies a recurrent neural network layer for inference on MPSMatrices.
1604    /// Currently two types of recurrent layers are supported: ones that operate with convolutions on
1605    /// images:
1606    /// MPSRNNImageInferenceLayer and one that operates on matrices:
1607    /// MPSRNNMatrixInferenceLayer. The former can be often used to implement the latter by using 1x1-matrices, but due to
1608    /// image size restrictions and performance, it is advisable to use
1609    /// MPSRNNMatrixInferenceLayer for
1610    /// linear recurrent layers.
1611    /// A MPSRNNMatrixInferenceLayer is initialized using a
1612    /// MPSRNNLayerDescriptor, which further specifies the
1613    /// recurrent network layer, or an array of
1614    /// MPSRNNLayerDescriptors, which specifies a stack
1615    /// of recurrent layers, that can operate in parallel on a subset of the inputs in a sequence of inputs and
1616    /// recurrent outputs. Note that currently stacks with bidirectionally traversing encode functions do not support starting
1617    /// from a previous set of recurrent states, but this can be achieved quite easily by defining two separate
1618    /// unidirectional stacks of layers, and running the same input sequence on them separately (one forwards and one backwards)
1619    /// and ultimately combining the two result sequences as desired with auxiliary functions.
1620    /// The input and output vectors in encode calls are stored as rows of the input and output matrices and
1621    /// MPSRNNMatrixInferenceLayer supports matrices with decreasing number of rows: The row-indices identify the different
1622    /// sequences that may be of different lengths - for example if we have three sequences:
1623    /// ( x1, x2, x3 ), ( y1, y2, y3, y4 ) and ( z1, z2 )
1624    /// of vectors xi, yi and zi, then these can be inserted together as a batch to the sequence encoding kernel by
1625    /// using the matrices:
1626    ///
1627    /// ```text
1628    ///                            ( y1 )        ( y2 )        ( y3 )        ( y4 )
1629    ///                       m1 = ( x1 ),  m2 = ( x2 ),  m3 = ( x3 ),  m4 =
1630    ///                            ( z1 )        ( z2 )
1631    /// ```
1632    ///
1633    /// If a recurrent output state is requested then it will contain the state corresponding to last inputs to each
1634    /// sequence and if all the intermediate states are requested (see storeAllIntermediateStates),
1635    /// then the shorter sequences will be propagated by copying the state of the previous output if the
1636    /// input vector is not present in the sequence - in the example above the output states would be:
1637    ///
1638    /// ```text
1639    ///                            ( s_y1 )        ( s_y2 )        ( s_y3 )        ( s_y4 )
1640    ///                       s1 = ( s_x1 ),  s2 = ( s_x2 ),  s3 = ( s_x3 ),  s4 = ( s_x3 )
1641    ///                            ( s_z1 )        ( s_z2 )        ( s_z2 )        ( s_z2 )
1642    /// ```
1643    ///
1644    /// The mathematical operation described in the linear transformations of
1645    /// MPSRNNSingleGateDescriptor, MPSLSTMDescriptor and
1646    /// MPSGRUDescriptor are `y^T = W x^T`
1647    /// `<=>`
1648    /// `y = x W^T`, where x is the matrix containing
1649    /// the input vectors as rows, y is the matrix containing the output vectors as rows and W is the weight matrix.
1650    ///
1651    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnmatrixinferencelayer?language=objc)
1652    #[unsafe(super(MPSKernel, NSObject))]
1653    #[derive(Debug, PartialEq, Eq, Hash)]
1654    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1655    pub struct MPSRNNMatrixInferenceLayer;
1656);
1657
1658#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1659extern_conformance!(
1660    unsafe impl NSCoding for MPSRNNMatrixInferenceLayer {}
1661);
1662
1663#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1664extern_conformance!(
1665    unsafe impl NSCopying for MPSRNNMatrixInferenceLayer {}
1666);
1667
1668#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1669unsafe impl CopyingHelper for MPSRNNMatrixInferenceLayer {
1670    type Result = Self;
1671}
1672
1673#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1674extern_conformance!(
1675    unsafe impl NSObjectProtocol for MPSRNNMatrixInferenceLayer {}
1676);
1677
1678#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1679extern_conformance!(
1680    unsafe impl NSSecureCoding for MPSRNNMatrixInferenceLayer {}
1681);
1682
1683#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1684impl MPSRNNMatrixInferenceLayer {
1685    extern_methods!(
1686        /// The number of feature channels in the input vector/matrix.
1687        #[unsafe(method(inputFeatureChannels))]
1688        #[unsafe(method_family = none)]
1689        pub unsafe fn inputFeatureChannels(&self) -> NSUInteger;
1690
1691        /// The number of feature channels in the output vector/matrix.
1692        #[unsafe(method(outputFeatureChannels))]
1693        #[unsafe(method_family = none)]
1694        pub unsafe fn outputFeatureChannels(&self) -> NSUInteger;
1695
1696        /// Number of layers in the filter-stack. This will be one when using initWithDevice:rnnDescriptor: to initialize
1697        /// this filter and the number of entries in the array 'rnnDescriptors' when initializing this filter with
1698        /// initWithDevice:rnnDescriptors:.
1699        #[unsafe(method(numberOfLayers))]
1700        #[unsafe(method_family = none)]
1701        pub unsafe fn numberOfLayers(&self) -> NSUInteger;
1702
1703        /// How output states from
1704        /// encodeSequenceToCommandBuffer are constructed.
1705        /// Defaults to NO.
1706        ///
1707        /// For reference, see: MPSState.
1708        #[unsafe(method(recurrentOutputIsTemporary))]
1709        #[unsafe(method_family = none)]
1710        pub unsafe fn recurrentOutputIsTemporary(&self) -> bool;
1711
1712        /// Setter for [`recurrentOutputIsTemporary`][Self::recurrentOutputIsTemporary].
1713        #[unsafe(method(setRecurrentOutputIsTemporary:))]
1714        #[unsafe(method_family = none)]
1715        pub unsafe fn setRecurrentOutputIsTemporary(&self, recurrent_output_is_temporary: bool);
1716
1717        /// If YES then calls to
1718        /// encodeSequenceToCommandBuffer return every recurrent state
1719        /// in the array: recurrentOutputStates.
1720        /// Defaults to NO.
1721        #[unsafe(method(storeAllIntermediateStates))]
1722        #[unsafe(method_family = none)]
1723        pub unsafe fn storeAllIntermediateStates(&self) -> bool;
1724
1725        /// Setter for [`storeAllIntermediateStates`][Self::storeAllIntermediateStates].
1726        #[unsafe(method(setStoreAllIntermediateStates:))]
1727        #[unsafe(method_family = none)]
1728        pub unsafe fn setStoreAllIntermediateStates(&self, store_all_intermediate_states: bool);
1729
1730        /// Defines how to combine the output-results, when encoding bidirectional layers using
1731        /// encodeBidirectionalSequenceToCommandBuffer. Defaults to
1732        /// MPSRNNBidirectionalCombineModeNone.
1733        #[unsafe(method(bidirectionalCombineMode))]
1734        #[unsafe(method_family = none)]
1735        pub unsafe fn bidirectionalCombineMode(&self) -> MPSRNNBidirectionalCombineMode;
1736
1737        /// Setter for [`bidirectionalCombineMode`][Self::bidirectionalCombineMode].
1738        #[unsafe(method(setBidirectionalCombineMode:))]
1739        #[unsafe(method_family = none)]
1740        pub unsafe fn setBidirectionalCombineMode(
1741            &self,
1742            bidirectional_combine_mode: MPSRNNBidirectionalCombineMode,
1743        );
1744
1745        /// Initializes a linear (fully connected) RNN kernel
1746        ///
1747        /// Parameter `device`: The MTLDevice on which this MPSRNNMatrixLayer filter will be used
1748        ///
1749        /// Parameter `rnnDescriptor`: The descriptor that defines the RNN layer
1750        ///
1751        /// Returns: A valid MPSRNNMatrixInferenceLayer object or nil, if failure. NOTE(review): this binding's return type is the non-optional `Retained<Self>` - confirm nullability against the header.
1752        #[unsafe(method(initWithDevice:rnnDescriptor:))]
1753        #[unsafe(method_family = init)]
1754        pub unsafe fn initWithDevice_rnnDescriptor(
1755            this: Allocated<Self>,
1756            device: &ProtocolObject<dyn MTLDevice>,
1757            rnn_descriptor: &MPSRNNDescriptor,
1758        ) -> Retained<Self>;
1759
1760        /// Initializes a kernel that implements a stack of linear (fully connected) RNN layers
1761        ///
1762        /// Parameter `device`: The MTLDevice on which this MPSRNNMatrixLayer filter will be used
1763        ///
1764        /// Parameter `rnnDescriptors`: An array of RNN descriptors that defines a stack of RNN layers, starting at index zero.
1765        /// The number of layers in the stack is the number of entries in the array.
1766        /// All entries in the array must be valid MPSRNNDescriptors.
1767        ///
1768        /// Returns: A valid MPSRNNMatrixInferenceLayer object or nil, if failure. NOTE(review): this binding's return type is the non-optional `Retained<Self>` - confirm nullability against the header.
1769        #[unsafe(method(initWithDevice:rnnDescriptors:))]
1770        #[unsafe(method_family = init)]
1771        pub unsafe fn initWithDevice_rnnDescriptors(
1772            this: Allocated<Self>,
1773            device: &ProtocolObject<dyn MTLDevice>,
1774            rnn_descriptors: &NSArray<MPSRNNDescriptor>,
1775        ) -> Retained<Self>;
1776
1777        #[unsafe(method(initWithDevice:))]
1778        #[unsafe(method_family = init)]
1779        pub unsafe fn initWithDevice(
1780            this: Allocated<Self>,
1781            device: &ProtocolObject<dyn MTLDevice>,
1782        ) -> Retained<Self>;
1783
1784        #[cfg(all(feature = "MPSMatrix", feature = "MPSState"))]
1785        /// Encode an MPSRNNMatrixInferenceLayer kernel (stack) for a sequence of inputs into a command buffer.
1786        /// Note that when encoding using this function the
1787        /// layerSequenceDirection
1788        /// is ignored and the layer stack operates as
1789        /// if all layers were forward feeding layers. In order to run bidirectional sequences
1790        /// use
1791        /// encodeBidirectionalSequenceToCommandBuffer:sourceSequence: or alternatively run two layer stacks and combine
1792        /// results at the end using utility functions.
1793        ///
1794        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
1795        ///
1796        /// Parameter `sourceMatrices`: An array of valid MPSMatrix objects containing the sequence of source matrices.
1797        ///
1798        /// Parameter `sourceOffsets`: An array of byte-offsets into the sourceMatrices, if nil zeros are assumed and
1799        /// if not nil must contain offset for every matrix in sourceMatrices.
1800        ///
1801        /// Parameter `destinationMatrices`: An array of valid MPSMatrices to be overwritten by result matrix sequence.
1802        /// destinationMatrices may not alias sourceMatrices.
1803        ///
1804        /// Parameter `destinationOffsets`: An array of byte-offsets into the destinationMatrices, if nil zeros are assumed and
1805        /// if not nil must contain offset for every matrix in destinationMatrices.
1806        ///
1807        /// Parameter `recurrentInputState`: An optional state containing the output matrices and memory cells (for LSTMs)
1808        /// of the layer obtained from the previous input matrices in a sequence of inputs.
1809        /// Has to be the output of a previous call to this function or nil (assumed zero).
1810        /// Note: can be one of the states returned in intermediateRecurrentStates.
1811        ///
1812        /// Parameter `recurrentOutputStates`: An optional array that will contain the recurrent output states. If nil then
1813        /// the recurrent output state is discarded.
1814        /// If
1815        /// storeAllIntermediateStates is YES, then all intermediate states of the sequence
1816        /// are returned in the array, the first one corresponding to the first input in the sequence,
1817        /// otherwise only the last recurrent output state is returned.
1818        /// If recurrentOutputIsTemporary is YES, then all returned recurrent states
1819        /// will be temporary.
1820        ///
1821        /// See: MPSState:isTemporary.
1822        /// Example: In order to get a new state one can do the following:
1823        ///
1824        /// ```text
1825        ///                                                       MPSRNNRecurrentMatrixState* recurrent0 = nil;
1826        ///                                                       [filter encodeToCommandBuffer: cmdBuf
1827        ///                                                                        sourceMatrix: source0
1828        ///                                                                   destinationMatrix: destination0
1829        ///                                                                 recurrentInputState: nil
1830        ///                                                                recurrentOutputState: &recurrent0];
1831        /// ```
1832        ///
1833        /// Then use it for the next input in sequence:
1834        ///
1835        /// ```text
1836        ///                                                       [filter encodeToCommandBuffer: cmdBuf
1837        ///                                                                        sourceMatrix: source1
1838        ///                                                                   destinationMatrix: destination1
1839        ///                                                                 recurrentInputState: recurrent0
1840        ///                                                                recurrentOutputState: &recurrent0];
1841        /// ```
1842        ///
1843        /// And discard recurrent output of the third input:
1844        ///
1845        /// ```text
1846        ///                                                       [filter encodeToCommandBuffer: cmdBuf
1847        ///                                                                        sourceMatrix: source2
1848        ///                                                                   destinationMatrix: destination2
1849        ///                                                                 recurrentInputState: recurrent0
1850        ///                                                                recurrentOutputState: nil];
1851        /// ```
1852        ///
1853        /// # Safety
1854        ///
1855        /// - `source_offsets` must be a valid pointer or null.
1856        /// - `destination_offsets` must be a valid pointer or null.
1857        #[unsafe(method(encodeSequenceToCommandBuffer:sourceMatrices:sourceOffsets:destinationMatrices:destinationOffsets:recurrentInputState:recurrentOutputStates:))]
1858        #[unsafe(method_family = none)]
1859        pub unsafe fn encodeSequenceToCommandBuffer_sourceMatrices_sourceOffsets_destinationMatrices_destinationOffsets_recurrentInputState_recurrentOutputStates(
1860            &self,
1861            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1862            source_matrices: &NSArray<MPSMatrix>,
1863            source_offsets: *mut NSUInteger,
1864            destination_matrices: &NSArray<MPSMatrix>,
1865            destination_offsets: *mut NSUInteger,
1866            recurrent_input_state: Option<&MPSRNNRecurrentMatrixState>,
1867            recurrent_output_states: Option<&NSMutableArray<MPSRNNRecurrentMatrixState>>,
1868        );
1869
1870        #[cfg(all(feature = "MPSMatrix", feature = "MPSState"))]
        /// Binding for the
        /// `encodeSequenceToCommandBuffer:sourceMatrices:destinationMatrices:recurrentInputState:recurrentOutputStates:`
        /// selector.
        ///
        /// The parameter list is the same as that of
        /// `encodeSequenceToCommandBuffer:sourceMatrices:sourceOffsets:destinationMatrices:destinationOffsets:recurrentInputState:recurrentOutputStates:`
        /// with the two offset arrays removed.
        ///
        /// NOTE(review): presumably this behaves like the offset-taking variant with all
        /// offsets equal to zero - confirm against Apple's header before relying on it.
        ///
        /// Parameter `command_buffer`: The command buffer handed to the selector.
        ///
        /// Parameter `source_matrices`: The array of source matrices.
        ///
        /// Parameter `destination_matrices`: The array of destination matrices.
        ///
        /// Parameter `recurrent_input_state`: Optional recurrent input state (may be `None`).
        ///
        /// Parameter `recurrent_output_states`: Optional mutable array for the recurrent output states (may be `None`).
        ///
        /// # Safety
        ///
        /// NOTE(review): this binding is `unsafe`, but no safety contract is stated for it here;
        /// check Apple's documentation for requirements on the arguments before calling.
1871        #[unsafe(method(encodeSequenceToCommandBuffer:sourceMatrices:destinationMatrices:recurrentInputState:recurrentOutputStates:))]
1872        #[unsafe(method_family = none)]
1873        pub unsafe fn encodeSequenceToCommandBuffer_sourceMatrices_destinationMatrices_recurrentInputState_recurrentOutputStates(
1874            &self,
1875            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1876            source_matrices: &NSArray<MPSMatrix>,
1877            destination_matrices: &NSArray<MPSMatrix>,
1878            recurrent_input_state: Option<&MPSRNNRecurrentMatrixState>,
1879            recurrent_output_states: Option<&NSMutableArray<MPSRNNRecurrentMatrixState>>,
1880        );
1881
1882        #[cfg(feature = "MPSMatrix")]
1883        /// Encode an MPSRNNMatrixInferenceLayer kernel stack for an input matrix sequence into a command buffer bidirectionally.
1884        /// The operation proceeds as follows: The first source matrix x0 is passed through all forward traversing layers in the stack,
1885        /// i.e. those that were initialized with MPSRNNSequenceDirectionForward, recurrent input is assumed zero.
1886        /// This produces forward output yf0 and recurrent states hf00, hf01, hf02, ... hf0n, one for each forward layer in the stack.
1887        /// Then x1 is passed to forward layers together with recurrent state hf00, hf01, ..., hf0n, which produces yf1, and hf10,...
1888        /// This procedure is iterated until the last matrix in the input sequence x_(N-1), which produces forward output yf(N-1).
1889        /// The backwards layers iterate the same sequence backwards, starting from input x_(N-1) (recurrent state zero),
1890        /// that produces yb(N-1) and recurrent output hb(N-1)0, hb(N-1)1, ... hb(N-1)m, one for each backwards traversing layer.
1891        /// Then the backwards layers handle input x_(N-2) using recurrent state hb(N-1)0, ..., et cetera, until the
1892        /// first matrix of the sequence is computed, producing output yb0. The result of the operation is either a pair of sequences
1893        /// ({yf0, yf1, ... , yf(N-1)},  {yb0, yb1, ... , yb(N-1)}) or a combined sequence, {(yf0 + yb0), ... , (yf(N-1) + yb(N-1)) },
1894        /// where '+' stands either for sum, or concatenation along feature channels, as specified by
1895        /// bidirectionalCombineMode.
1896        ///
1897        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
1898        ///
1899        /// Parameter `sourceSequence`: An array of valid MPSMatrix objects containing the source matrix sequence (x0, x1, ... x_n-1).
1900        ///
1901        /// Parameter `destinationForwardMatrices`: An array of valid MPSMatrices to be overwritten by result from forward input matrices. If bidirectionalCombineMode
1902        /// is either MPSRNNBidirectionalCombineModeAdd or MPSRNNBidirectionalCombineModeConcatenate, then will
1903        /// contain the combined results. destinationForwardMatrices may not alias with any of the source matrices.
1904        ///
1905        /// Parameter `destinationBackwardMatrices`: If bidirectionalCombineMode is MPSRNNBidirectionalCombineModeNone, then must be an array of valid MPSMatrices
1906        /// that will be overwritten by result from backward input matrices. Otherwise this parameter is ignored
1907        /// and can be nil. destinationBackwardMatrices may not alias to any of the source matrices.
        ///
        /// # Safety
        ///
        /// NOTE(review): this binding is `unsafe`, but no safety contract is spelled out for it.
        /// The aliasing restrictions above are stated only in prose; presumably the caller must
        /// uphold them - confirm against Apple's documentation.
1908        #[unsafe(method(encodeBidirectionalSequenceToCommandBuffer:sourceSequence:destinationForwardMatrices:destinationBackwardMatrices:))]
1909        #[unsafe(method_family = none)]
1910        pub unsafe fn encodeBidirectionalSequenceToCommandBuffer_sourceSequence_destinationForwardMatrices_destinationBackwardMatrices(
1911            &self,
1912            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1913            source_sequence: &NSArray<MPSMatrix>,
1914            destination_forward_matrices: &NSArray<MPSMatrix>,
1915            destination_backward_matrices: Option<&NSArray<MPSMatrix>>,
1916        );
1917
1918        /// NSSecureCoding compatibility
1919        ///
1920        /// See MPSKernel#initWithCoder.
1921        ///
1922        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSRNNMatrixInferenceLayer
1923        ///
1924        /// Parameter `device`: The MTLDevice on which to make the MPSRNNMatrixInferenceLayer
1925        ///
1926        /// Returns: A new MPSRNNMatrixInferenceLayer object, or nil (`None` in this binding) if failure.
1927        ///
1928        /// # Safety
1929        ///
1930        /// `a_decoder` possibly has further requirements.
1931        #[unsafe(method(initWithCoder:device:))]
1932        #[unsafe(method_family = init)]
1933        pub unsafe fn initWithCoder_device(
1934            this: Allocated<Self>,
1935            a_decoder: &NSCoder,
1936            device: &ProtocolObject<dyn MTLDevice>,
1937        ) -> Option<Retained<Self>>;
1938
1939        /// Make a copy of this kernel for a new device.
1940        ///
1941        /// See: MPSKernel
1942        ///
1943        /// Parameter `zone`: The NSZone in which to allocate the object
1944        ///
1945        /// Parameter `device`: The device for the new MPSKernel. If nil, then use
1946        /// self.device.
1947        ///
1948        /// Returns: a pointer to a copy of this MPSKernel. This will fail, returning
1949        /// nil if the device is not supported. Devices must be
1950        /// MTLFeatureSet_iOS_GPUFamily2_v1 or later.
        ///
        /// NOTE(review): the text above describes a nil return on failure, but this binding's
        /// return type is the non-optional `Retained<Self>`. Verify what happens for an
        /// unsupported device before relying on this call.
1951        ///
1952        /// # Safety
1953        ///
1954        /// `zone` must be a valid pointer or null.
1955        #[unsafe(method(copyWithZone:device:))]
1956        #[unsafe(method_family = copy)]
1957        pub unsafe fn copyWithZone_device(
1958            &self,
1959            zone: *mut NSZone,
1960            device: Option<&ProtocolObject<dyn MTLDevice>>,
1961        ) -> Retained<Self>;
1962    );
1963}
1964
1965/// Methods declared on superclass `MPSKernel`.
1966#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1967impl MPSRNNMatrixInferenceLayer {
1968    extern_methods!(
1969        /// Called by NSCoder to decode MPSKernels
1970        ///
1971        /// This isn't the right interface to decode a MPSKernel, but
1972        /// it is the one that NSCoder uses. To enable your NSCoder
1973        /// (e.g. NSKeyedUnarchiver) to set which device to use
1974        /// extend the object to adopt the MPSDeviceProvider
1975        /// protocol. Otherwise, the Metal system default device
1976        /// will be used.
        ///
        /// See also `initWithCoder_device` on this type, which takes the device explicitly.
        ///
        /// Returns: `Option<Retained<Self>>`. NOTE(review): presumably `None` signals a
        /// decoding failure - confirm against Apple's documentation.
1977        ///
1978        /// # Safety
1979        ///
1980        /// `a_decoder` possibly has further requirements.
1981        #[unsafe(method(initWithCoder:))]
1982        #[unsafe(method_family = init)]
1983        pub unsafe fn initWithCoder(
1984            this: Allocated<Self>,
1985            a_decoder: &NSCoder,
1986        ) -> Option<Retained<Self>>;
1987    );
1988}
1989
1990/// Methods declared on superclass `NSObject`.
1991#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1992impl MPSRNNMatrixInferenceLayer {
1993    extern_methods!(
        /// Binding for the plain `init` selector (method family `init`).
        ///
        /// Consumes the `Allocated<Self>` passed as `this` and returns a `Retained<Self>`.
        ///
        /// # Safety
        ///
        /// NOTE(review): no safety contract is stated for this `unsafe` binding.
        /// The other initializers of this class visible in this file take a device or a
        /// coder; whether a layer created through plain `init` is usable is not shown
        /// here - confirm before calling.
1994        #[unsafe(method(init))]
1995        #[unsafe(method_family = init)]
1996        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
1997
        /// Binding for the `new` class selector (method family `new`).
        ///
        /// Takes no arguments and returns a `Retained<Self>`.
        ///
        /// # Safety
        ///
        /// NOTE(review): no safety contract is stated for this `unsafe` binding; the same
        /// caveat as for `init` presumably applies - confirm before calling.
1998        #[unsafe(method(new))]
1999        #[unsafe(method_family = new)]
2000        pub unsafe fn new() -> Retained<Self>;
2001    );
2002}
2003
2004extern_class!(
2005    /// Dependencies: This depends on Metal.framework
2006    ///
2007    /// This class holds the data produced by the forward pass that is needed in the backward pass.
    ///
    /// Declared here with the superclass chain `MPSState` -> `NSObject`, and only
    /// compiled when both the `MPSCore` and `MPSState` features are enabled.
2008    ///
2009    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnmatrixtrainingstate?language=objc)
2010    #[unsafe(super(MPSState, NSObject))]
2011    #[derive(Debug, PartialEq, Eq, Hash)]
2012    #[cfg(all(feature = "MPSCore", feature = "MPSState"))]
2013    pub struct MPSRNNMatrixTrainingState;
2014);
2015
// Declares that `MPSRNNMatrixTrainingState` conforms to `NSObjectProtocol`.
// Gated on the same features as the class declaration itself.
2016#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
2017extern_conformance!(
2018    unsafe impl NSObjectProtocol for MPSRNNMatrixTrainingState {}
2019);
2020
// The `extern_methods!` invocation below is empty: no methods are bound for this
// class directly in this block. Bindings for methods declared on its superclasses
// are in the separate `impl` blocks that follow in this file.
2021#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
2022impl MPSRNNMatrixTrainingState {
2023    extern_methods!();
2024}
2025
2026/// Methods declared on superclass `MPSState`.
2027#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
2028impl MPSRNNMatrixTrainingState {
2029    extern_methods!(
2030        /// Create a MPSState holding a temporary MTLBuffer
2031        ///
2032        /// Parameter `cmdBuf`: The command buffer against which the temporary resource is allocated
2033        ///
2034        /// Parameter `bufferSize`: The size of the buffer in bytes
2035        #[unsafe(method(temporaryStateWithCommandBuffer:bufferSize:))]
2036        #[unsafe(method_family = none)]
2037        pub unsafe fn temporaryStateWithCommandBuffer_bufferSize(
2038            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
2039            buffer_size: usize,
2040        ) -> Retained<Self>;
2041
2042        /// Create a MPSState holding a temporary MTLTexture
2043        ///
2044        /// Parameter `cmdBuf`: The command buffer against which the temporary resource is allocated
2045        ///
2046        /// Parameter `descriptor`: A descriptor for the new temporary texture
2047        #[unsafe(method(temporaryStateWithCommandBuffer:textureDescriptor:))]
2048        #[unsafe(method_family = none)]
2049        pub unsafe fn temporaryStateWithCommandBuffer_textureDescriptor(
2050            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
2051            descriptor: &MTLTextureDescriptor,
2052        ) -> Retained<Self>;
2053
2054        /// Create a new autoreleased temporary state object without underlying resource
2055        ///
2056        /// Parameter `cmdBuf`: The command buffer with which the temporary resource is associated
2057        #[unsafe(method(temporaryStateWithCommandBuffer:))]
2058        #[unsafe(method_family = none)]
2059        pub unsafe fn temporaryStateWithCommandBuffer(
2060            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
2061        ) -> Retained<Self>;
2062
        /// Binding for the `initWithDevice:bufferSize:` initializer.
        ///
        /// NOTE(review): by analogy with `temporaryStateWithCommandBuffer:bufferSize:` this
        /// presumably creates a non-temporary state backed by a buffer of `buffer_size`
        /// bytes on `device` - confirm against Apple's documentation.
2063        #[unsafe(method(initWithDevice:bufferSize:))]
2064        #[unsafe(method_family = init)]
2065        pub unsafe fn initWithDevice_bufferSize(
2066            this: Allocated<Self>,
2067            device: &ProtocolObject<dyn MTLDevice>,
2068            buffer_size: usize,
2069        ) -> Retained<Self>;
2070
        /// Binding for the `initWithDevice:textureDescriptor:` initializer.
        ///
        /// NOTE(review): by analogy with `temporaryStateWithCommandBuffer:textureDescriptor:`
        /// this presumably creates a non-temporary state backed by a texture described by
        /// `descriptor` on `device` - confirm against Apple's documentation.
2071        #[unsafe(method(initWithDevice:textureDescriptor:))]
2072        #[unsafe(method_family = init)]
2073        pub unsafe fn initWithDevice_textureDescriptor(
2074            this: Allocated<Self>,
2075            device: &ProtocolObject<dyn MTLDevice>,
2076            descriptor: &MTLTextureDescriptor,
2077        ) -> Retained<Self>;
2078
2079        /// Create a MPSState with a non-temporary MTLResource
2080        ///
2081        /// Parameter `resource`: A MTLBuffer or MTLTexture. May be nil.
2082        ///
2083        /// # Safety
2084        ///
2085        /// - `resource` may need to be synchronized.
2086        /// - `resource` may be unretained, you must ensure it is kept alive while in use.
2087        #[unsafe(method(initWithResource:))]
2088        #[unsafe(method_family = init)]
2089        pub unsafe fn initWithResource(
2090            this: Allocated<Self>,
2091            resource: Option<&ProtocolObject<dyn MTLResource>>,
2092        ) -> Retained<Self>;
2093
        /// Binding for the plain `init` selector.
        ///
        /// Unlike the other initializers in this block, the return type is
        /// `Option<Retained<Self>>`, so callers must handle a `None` result.
2094        #[unsafe(method(init))]
2095        #[unsafe(method_family = init)]
2096        pub unsafe fn init(this: Allocated<Self>) -> Option<Retained<Self>>;
2097
2098        /// Initialize a non-temporary state to hold a number of textures and buffers
2099        ///
2100        /// The allocation of each resource will be deferred until it is needed.
2101        /// This occurs when -resource or -resourceAtIndex: is called.
2102        ///
2103        /// Parameter `resourceList`: The list of resources to create.
2104        #[unsafe(method(initWithDevice:resourceList:))]
2105        #[unsafe(method_family = init)]
2106        pub unsafe fn initWithDevice_resourceList(
2107            this: Allocated<Self>,
2108            device: &ProtocolObject<dyn MTLDevice>,
2109            resource_list: &MPSStateResourceList,
2110        ) -> Retained<Self>;
2111
2112        /// Initialize a temporary state to hold a number of textures and buffers
2113        ///
2114        /// The textures occur first in sequence
2115        #[unsafe(method(temporaryStateWithCommandBuffer:resourceList:))]
2116        #[unsafe(method_family = none)]
2117        pub unsafe fn temporaryStateWithCommandBuffer_resourceList(
2118            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
2119            resource_list: &MPSStateResourceList,
2120        ) -> Retained<Self>;
2121
2122        /// Create a state object with a list of MTLResources
2123        ///
2124        /// Because MPS prefers deferred allocation of resources
2125        /// your application should use -initWithTextures:bufferSizes:bufferCount:
2126        /// whenever possible. This method is useful for cases when the
2127        /// MTLResources must be initialized by the CPU.
2128        ///
2129        /// # Safety
2130        ///
2131        /// - `resources` generic may need to be synchronized.
2132        /// - `resources` generic may be unretained, you must ensure it is kept alive while in use.
2133        #[unsafe(method(initWithResources:))]
2134        #[unsafe(method_family = init)]
2135        pub unsafe fn initWithResources(
2136            this: Allocated<Self>,
2137            resources: Option<&NSArray<ProtocolObject<dyn MTLResource>>>,
2138        ) -> Retained<Self>;
2139    );
2140}
2141
2142/// Methods declared on superclass `NSObject`.
2143#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
2144impl MPSRNNMatrixTrainingState {
2145    extern_methods!(
        /// Binding for the `new` class selector (method family `new`).
        ///
        /// Takes no arguments and returns a `Retained<Self>`.
        ///
        /// # Safety
        ///
        /// NOTE(review): no safety contract is stated for this `unsafe` binding. The `init`
        /// binding for this class returns an `Option`, while this one does not; confirm
        /// how a failing `new` behaves before relying on it.
2146        #[unsafe(method(new))]
2147        #[unsafe(method_family = new)]
2148        pub unsafe fn new() -> Retained<Self>;
2149    );
2150}
2151
2152/// [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnmatrixid?language=objc)
///
/// Transparent newtype over `NSUInteger` whose associated constants name the
/// individual weight and bias matrices of an RNN layer. The values are contiguous:
///
/// - `0..=2`: single-gate matrices,
/// - `3..=18`: LSTM matrices (input, forget, memory and output gate; four each),
/// - `19..=28`: GRU matrices (input, recurrent and output gate),
/// - `29`: [`Self::_count`], one past the last identifier.
///
/// Because the inner field is public, values outside these constants can be constructed.
2153// NS_ENUM
2154#[repr(transparent)]
2155#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
2156pub struct MPSRNNMatrixId(pub NSUInteger);
2157impl MPSRNNMatrixId {
    // --- Single-gate identifiers (0..=2) ---
2158    #[doc(alias = "MPSRNNMatrixIdSingleGateInputWeights")]
2159    pub const SingleGateInputWeights: Self = Self(0);
2160    #[doc(alias = "MPSRNNMatrixIdSingleGateRecurrentWeights")]
2161    pub const SingleGateRecurrentWeights: Self = Self(1);
2162    #[doc(alias = "MPSRNNMatrixIdSingleGateBiasTerms")]
2163    pub const SingleGateBiasTerms: Self = Self(2);
    // --- LSTM identifiers (3..=18): input, recurrent, memory weights and bias per gate ---
2164    #[doc(alias = "MPSRNNMatrixIdLSTMInputGateInputWeights")]
2165    pub const LSTMInputGateInputWeights: Self = Self(3);
2166    #[doc(alias = "MPSRNNMatrixIdLSTMInputGateRecurrentWeights")]
2167    pub const LSTMInputGateRecurrentWeights: Self = Self(4);
2168    #[doc(alias = "MPSRNNMatrixIdLSTMInputGateMemoryWeights")]
2169    pub const LSTMInputGateMemoryWeights: Self = Self(5);
2170    #[doc(alias = "MPSRNNMatrixIdLSTMInputGateBiasTerms")]
2171    pub const LSTMInputGateBiasTerms: Self = Self(6);
2172    #[doc(alias = "MPSRNNMatrixIdLSTMForgetGateInputWeights")]
2173    pub const LSTMForgetGateInputWeights: Self = Self(7);
2174    #[doc(alias = "MPSRNNMatrixIdLSTMForgetGateRecurrentWeights")]
2175    pub const LSTMForgetGateRecurrentWeights: Self = Self(8);
2176    #[doc(alias = "MPSRNNMatrixIdLSTMForgetGateMemoryWeights")]
2177    pub const LSTMForgetGateMemoryWeights: Self = Self(9);
2178    #[doc(alias = "MPSRNNMatrixIdLSTMForgetGateBiasTerms")]
2179    pub const LSTMForgetGateBiasTerms: Self = Self(10);
2180    #[doc(alias = "MPSRNNMatrixIdLSTMMemoryGateInputWeights")]
2181    pub const LSTMMemoryGateInputWeights: Self = Self(11);
2182    #[doc(alias = "MPSRNNMatrixIdLSTMMemoryGateRecurrentWeights")]
2183    pub const LSTMMemoryGateRecurrentWeights: Self = Self(12);
2184    #[doc(alias = "MPSRNNMatrixIdLSTMMemoryGateMemoryWeights")]
2185    pub const LSTMMemoryGateMemoryWeights: Self = Self(13);
2186    #[doc(alias = "MPSRNNMatrixIdLSTMMemoryGateBiasTerms")]
2187    pub const LSTMMemoryGateBiasTerms: Self = Self(14);
2188    #[doc(alias = "MPSRNNMatrixIdLSTMOutputGateInputWeights")]
2189    pub const LSTMOutputGateInputWeights: Self = Self(15);
2190    #[doc(alias = "MPSRNNMatrixIdLSTMOutputGateRecurrentWeights")]
2191    pub const LSTMOutputGateRecurrentWeights: Self = Self(16);
2192    #[doc(alias = "MPSRNNMatrixIdLSTMOutputGateMemoryWeights")]
2193    pub const LSTMOutputGateMemoryWeights: Self = Self(17);
2194    #[doc(alias = "MPSRNNMatrixIdLSTMOutputGateBiasTerms")]
2195    pub const LSTMOutputGateBiasTerms: Self = Self(18);
    // --- GRU identifiers (19..=28); the output gate has an extra "input gate weights" entry ---
2196    #[doc(alias = "MPSRNNMatrixIdGRUInputGateInputWeights")]
2197    pub const GRUInputGateInputWeights: Self = Self(19);
2198    #[doc(alias = "MPSRNNMatrixIdGRUInputGateRecurrentWeights")]
2199    pub const GRUInputGateRecurrentWeights: Self = Self(20);
2200    #[doc(alias = "MPSRNNMatrixIdGRUInputGateBiasTerms")]
2201    pub const GRUInputGateBiasTerms: Self = Self(21);
2202    #[doc(alias = "MPSRNNMatrixIdGRURecurrentGateInputWeights")]
2203    pub const GRURecurrentGateInputWeights: Self = Self(22);
2204    #[doc(alias = "MPSRNNMatrixIdGRURecurrentGateRecurrentWeights")]
2205    pub const GRURecurrentGateRecurrentWeights: Self = Self(23);
2206    #[doc(alias = "MPSRNNMatrixIdGRURecurrentGateBiasTerms")]
2207    pub const GRURecurrentGateBiasTerms: Self = Self(24);
2208    #[doc(alias = "MPSRNNMatrixIdGRUOutputGateInputWeights")]
2209    pub const GRUOutputGateInputWeights: Self = Self(25);
2210    #[doc(alias = "MPSRNNMatrixIdGRUOutputGateRecurrentWeights")]
2211    pub const GRUOutputGateRecurrentWeights: Self = Self(26);
2212    #[doc(alias = "MPSRNNMatrixIdGRUOutputGateInputGateWeights")]
2213    pub const GRUOutputGateInputGateWeights: Self = Self(27);
2214    #[doc(alias = "MPSRNNMatrixIdGRUOutputGateBiasTerms")]
2215    pub const GRUOutputGateBiasTerms: Self = Self(28);
    /// Number of matrix identifiers defined above (`0..=28`); not itself a matrix id.
2216    #[doc(alias = "MPSRNNMatrixId_count")]
2217    pub const _count: Self = Self(29);
2218}
2219
// SAFETY: `MPSRNNMatrixId` is declared `#[repr(transparent)]` over a single
// `NSUInteger` field, so it reuses that type's Objective-C type encoding.
2220unsafe impl Encode for MPSRNNMatrixId {
2221    const ENCODING: Encoding = NSUInteger::ENCODING;
2222}
2223
// SAFETY: a reference to `MPSRNNMatrixId` is encoded as a pointer to the
// encoding declared in the `Encode` impl above.
2224unsafe impl RefEncode for MPSRNNMatrixId {
2225    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
2226}
2227
2228extern_class!(
2229    /// Dependencies: This depends on Metal.framework
2230    ///
2231    /// The MPSRNNMatrixTrainingLayer specifies a recurrent neural network layer for training on MPSMatrices.
2232    ///
2233    /// A MPSRNNMatrixTrainingLayer is initialized using a
2234    /// MPSRNNLayerDescriptor, which further specifies the
2235    /// recurrent network layer.
2236    /// The input and output vectors in encode calls are stored as rows of the input and output matrices and
2237    /// MPSRNNMatrixTrainingLayer supports matrices with decreasing number of rows: The row-indices identify the different
2238    /// sequences that may be of different lengths - for example if we have three sequences:
2239    /// ( x1, x2, x3 ), ( y1, y2, y3, y4 ) and ( z1, z2 )
2240    /// of vectors xi, yi and zi, then these can be inserted together as a batch to the sequence encoding kernel by
2241    /// using the matrices:
2242    ///
2243    /// ```text
2244    ///                            ( y1 )        ( y2 )        ( y3 )        ( y4 )
2245    ///                       m1 = ( x1 ),  m2 = ( x2 ),  m3 = ( x3 ),  m4 =
2246    ///                            ( z1 )        ( z2 )
2247    /// ```
2248    ///
2249    /// The gradient computation pass is then achieved by passing the corresponding gradient sequence from the
2250    /// previous layer ( dx1, dx2, dx3 ), ( dy1, dy2, dy3, dy4 ) and ( dz1, dz2 ) as matrices
2251    ///
2252    /// ```text
2253    ///                             ( dy1 )         ( dy2 )         ( dy3 )         ( dy4 )
2254    ///                       dm1 = ( dx1 ),  dm2 = ( dx2 ),  dm3 = ( dx3 ),  dm4 =
2255    ///                             ( dz1 )         ( dz2 )
2256    /// ```
2257    ///
2258    /// The mathematical operation described in the linear transformations of
2259    /// MPSRNNSingleGateDescriptor, MPSLSTMDescriptor and
2260    /// MPSGRUDescriptor is
2261    /// y^T = W x^T <=> y = x W^T,
2262    /// where x is the matrix containing
2263    /// the input vectors as rows, y is the matrix containing the output vectors as rows and W is the weight matrix.
2264    ///
2265    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsrnnmatrixtraininglayer?language=objc)
2266    #[unsafe(super(MPSKernel, NSObject))]
2267    #[derive(Debug, PartialEq, Eq, Hash)]
2268    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2269    pub struct MPSRNNMatrixTrainingLayer;
2270);
2271
// Protocol conformances for `MPSRNNMatrixTrainingLayer`. Each one is gated on the
// same `MPSCore` + `MPSKernel` feature pair as the class declaration.

// Conformance to `NSCoding`.
2272#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2273extern_conformance!(
2274    unsafe impl NSCoding for MPSRNNMatrixTrainingLayer {}
2275);
2276
// Conformance to `NSCopying`.
2277#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2278extern_conformance!(
2279    unsafe impl NSCopying for MPSRNNMatrixTrainingLayer {}
2280);
2281
// `Result = Self`: the copying helper is declared to yield the same type.
2282#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2283unsafe impl CopyingHelper for MPSRNNMatrixTrainingLayer {
2284    type Result = Self;
2285}
2286
// Conformance to `NSObjectProtocol`.
2287#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2288extern_conformance!(
2289    unsafe impl NSObjectProtocol for MPSRNNMatrixTrainingLayer {}
2290);
2291
// Conformance to `NSSecureCoding`.
2292#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2293extern_conformance!(
2294    unsafe impl NSSecureCoding for MPSRNNMatrixTrainingLayer {}
2295);
2296
2297#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2298impl MPSRNNMatrixTrainingLayer {
2299    extern_methods!(
2300        /// The number of feature channels in the input vector/matrix.
        ///
        /// Returns the value of the `inputFeatureChannels` selector as an `NSUInteger`.
2301        #[unsafe(method(inputFeatureChannels))]
2302        #[unsafe(method_family = none)]
2303        pub unsafe fn inputFeatureChannels(&self) -> NSUInteger;
2304
2305        /// The number of feature channels in the output vector/matrix.
        ///
        /// Returns the value of the `outputFeatureChannels` selector as an `NSUInteger`.
2306        #[unsafe(method(outputFeatureChannels))]
2307        #[unsafe(method_family = none)]
2308        pub unsafe fn outputFeatureChannels(&self) -> NSUInteger;
2309
2310        /// If YES then calls to functions
2311        /// encodeForwardSequenceToCommandBuffer and
2312        /// encodeGradientSequenceToCommandBuffer return every recurrent state
2313        /// in the array: recurrentOutputStates.
2314        /// Defaults to NO.
2315        #[unsafe(method(storeAllIntermediateStates))]
2316        #[unsafe(method_family = none)]
2317        pub unsafe fn storeAllIntermediateStates(&self) -> bool;
2318
2319        /// Setter for [`storeAllIntermediateStates`][Self::storeAllIntermediateStates].
        ///
        /// Parameter `store_all_intermediate_states`: The new value passed to the
        /// `setStoreAllIntermediateStates:` selector.
2320        #[unsafe(method(setStoreAllIntermediateStates:))]
2321        #[unsafe(method_family = none)]
2322        pub unsafe fn setStoreAllIntermediateStates(&self, store_all_intermediate_states: bool);
2323
2324        /// How recurrent output states from
2325        /// encodeForwardSequenceToCommandBuffer and encodeGradientSequenceToCommandBuffer are constructed.
2326        /// Defaults to NO.
2327        ///
2328        /// For reference, see: MPSState.
        ///
        /// NOTE(review): going by the property name, YES presumably means the recurrent
        /// output states are created as temporary states - confirm against Apple's documentation.
2329        #[unsafe(method(recurrentOutputIsTemporary))]
2330        #[unsafe(method_family = none)]
2331        pub unsafe fn recurrentOutputIsTemporary(&self) -> bool;
2332
2333        /// Setter for [`recurrentOutputIsTemporary`][Self::recurrentOutputIsTemporary].
        ///
        /// Parameter `recurrent_output_is_temporary`: The new value passed to the
        /// `setRecurrentOutputIsTemporary:` selector.
2334        #[unsafe(method(setRecurrentOutputIsTemporary:))]
2335        #[unsafe(method_family = none)]
2336        pub unsafe fn setRecurrentOutputIsTemporary(&self, recurrent_output_is_temporary: bool);
2337
2338        /// How training output states from
2339        /// encodeForwardSequenceToCommandBuffer are constructed.
2340        /// Defaults to NO.
2341        ///
2342        /// For reference, see: MPSState.
        ///
        /// NOTE(review): going by the property name, YES presumably means the training
        /// states are created as temporary states - confirm against Apple's documentation.
2343        #[unsafe(method(trainingStateIsTemporary))]
2344        #[unsafe(method_family = none)]
2345        pub unsafe fn trainingStateIsTemporary(&self) -> bool;
2346
2347        /// Setter for [`trainingStateIsTemporary`][Self::trainingStateIsTemporary].
        ///
        /// Parameter `training_state_is_temporary`: The new value passed to the
        /// `setTrainingStateIsTemporary:` selector.
2348        #[unsafe(method(setTrainingStateIsTemporary:))]
2349        #[unsafe(method_family = none)]
2350        pub unsafe fn setTrainingStateIsTemporary(&self, training_state_is_temporary: bool);
2351
2352        /// If YES then the computed weight gradients are accumulated on top of existing values in
2353        /// calls to the gradient computation functions: encodeGradientSequenceToCommandBuffer.
2354        /// Defaults to NO.
2355        #[unsafe(method(accumulateWeightGradients))]
2356        #[unsafe(method_family = none)]
2357        pub unsafe fn accumulateWeightGradients(&self) -> bool;
2358
2359        /// Setter for [`accumulateWeightGradients`][Self::accumulateWeightGradients].
        ///
        /// Parameter `accumulate_weight_gradients`: The new value passed to the
        /// `setAccumulateWeightGradients:` selector.
2360        #[unsafe(method(setAccumulateWeightGradients:))]
2361        #[unsafe(method_family = none)]
2362        pub unsafe fn setAccumulateWeightGradients(&self, accumulate_weight_gradients: bool);
2363
2364        #[cfg(feature = "MPSMatrix")]
2365        /// Initializes a linear (fully connected) RNN kernel for training
2366        ///
2367        /// Parameter `device`: The MTLDevice on which this MPSRNNMatrixLayer filter will be used
2368        ///
2369        /// Parameter `rnnDescriptor`: The descriptor that defines the RNN layer
2370        ///
2371        /// Parameter `trainableWeights`: An array in which to store the weights of the layer as MPSMatrices.
2372        /// NOTE: The exact layout and number of matrices may vary between
2373        /// platforms and therefore you should not save out these weights directly,
2374        /// but instead use the function encodeCopyWeightsToCommandBuffer to identify
2375        /// the weights and biases for serialization.
2376        /// Typically you should pass here an initialized but empty NSMutableArray and
2377        /// when this function returns the array will have been populated with the
2378        /// weight matrices needed in the encode-calls, by using initial values from
2379        /// the datasources in rnnDescriptor.
2380        ///
2381        /// Returns: A valid MPSRNNMatrixTrainingLayer object or nil, if failure.
        ///
        /// NOTE(review): the text above mentions a nil return on failure, but this binding's
        /// return type is the non-optional `Retained<Self>`. Verify the failure behaviour
        /// before relying on this call.
2382        #[unsafe(method(initWithDevice:rnnDescriptor:trainableWeights:))]
2383        #[unsafe(method_family = init)]
2384        pub unsafe fn initWithDevice_rnnDescriptor_trainableWeights(
2385            this: Allocated<Self>,
2386            device: &ProtocolObject<dyn MTLDevice>,
2387            rnn_descriptor: &MPSRNNDescriptor,
2388            trainable_weights: &NSMutableArray<MPSMatrix>,
2389        ) -> Retained<Self>;
2390
2391        #[cfg(all(feature = "MPSCoreTypes", feature = "MPSMatrix"))]
2392        /// Initializes a set of matrices that can be used in training for weight and bias gradient outputs in
2393        /// encodeBackwardSequenceToCommandBuffer.
2394        ///
2395        /// Can also be used to easily create auxiliary matrices, for example
2396        /// for ADAM and other advanced optimization schemes. The layout and number of matrices is the same as for the outputs of
2397        /// initWithDevice, but the data type may differ. NOTE: These matrices cannot be used as weight matrices in the
2398        /// forward and backward encode calls, but matrices from initWithDevice() or createWeightMatrices() should be used instead.
2399        ///
2400        /// Parameter `matricesOut`: An array where the newly created matrices will be stored, will be initialized to zero.
2401        ///
2402        /// Parameter `dataType`: Datatype for the entries - currently MPSDataTypeFloat32 and MPSDataTypeFloat16 are supported.
2403        #[unsafe(method(createWeightGradientMatrices:dataType:))]
2404        #[unsafe(method_family = none)]
2405        pub unsafe fn createWeightGradientMatrices_dataType(
2406            &self,
2407            matrices_out: &NSMutableArray<MPSMatrix>,
2408            data_type: MPSDataType,
2409        );
2410
2411        #[cfg(all(feature = "MPSCoreTypes", feature = "MPSMatrix"))]
2412        /// Same as
2413        /// createWeightGradientMatrices, but the matrices will be temporary with readCount = 1, which means that they
2414        /// become invalid after the first encode call that reads them. Note also that as the matrices are temporary, their
2415        /// storage mode will be private which means that you can only access the data using a kernel on the GPU.
2416        ///
2417        /// Parameter `matricesOut`: An array where the newly created matrices will be stored, will be initialized to zero.
2418        ///
2419        /// Parameter `dataType`: Datatype for the entries - currently MPSDataTypeFloat32 and MPSDataTypeFloat16 are supported.
2420        ///
2421        /// Parameter `commandBuffer`: The command buffer that the temporary matrices will live on.
2422        #[unsafe(method(createTemporaryWeightGradientMatrices:dataType:commandBuffer:))]
2423        #[unsafe(method_family = none)]
2424        pub unsafe fn createTemporaryWeightGradientMatrices_dataType_commandBuffer(
2425            &self,
2426            matrices_out: &NSMutableArray<MPSMatrix>,
2427            data_type: MPSDataType,
2428            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
2429        );
2430
2431        #[cfg(feature = "MPSMatrix")]
2432        /// Initializes a set of matrices that can be used in training for weight and bias matrices in
2433        /// the forward and backward passes. The layout, datatype and number of matrices is the same as for the outputs of
2434        /// initWithDevice.
2435        /// NOTE(review): "initWithDevice" presumably means `initWithDevice:rnnDescriptor:trainableWeights:` - confirm.
2436        ///
2437        /// Parameter `matricesOut`: An array where the newly created matrices will be stored, will be initialized to zero.
2438        #[unsafe(method(createWeightMatrices:))]
2439        #[unsafe(method_family = none)]
2440        pub unsafe fn createWeightMatrices(&self, matrices_out: &NSMutableArray<MPSMatrix>);
2441
        /// Binding for the `initWithDevice:` initializer.
        ///
        /// Parameter `device`: The MTLDevice passed to the initializer.
        ///
        /// NOTE(review): this initializer takes no RNN descriptor and no weight array, unlike
        /// `initWithDevice:rnnDescriptor:trainableWeights:`. Whether a layer created this way
        /// is usable is not shown here - confirm against Apple's header before calling.
2442        #[unsafe(method(initWithDevice:))]
2443        #[unsafe(method_family = init)]
2444        pub unsafe fn initWithDevice(
2445            this: Allocated<Self>,
2446            device: &ProtocolObject<dyn MTLDevice>,
2447        ) -> Retained<Self>;
2448
2449        #[cfg(feature = "MPSMatrix")]
2450        /// Encode a copy kernel that copies one matrix from the trainable weight set to a matrix with standard layout,
2451        /// where the column index is the input feature channel index (in forward direction) and row index is the output
2452        /// feature channel index.
2453        ///
2454        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
2455        ///
2456        /// Parameter `weights`: An array of weights from
2457        /// initWithDevice or createWeightMatrices.
2458        ///
2459        /// Parameter `matrixId`: Which matrix to copy - has to be a valid Id based on inputs defined in
2460        /// the rnnDescriptor of initWithDevice.
2461        ///
2462        /// Parameter `matrix`: The destination or source matrix that is used in the copy.
2463        ///
2464        /// Parameter `copyFromWeightsToMatrix`: If YES then the copy direction is from the set of trainable 'weights' to 'matrix',
2465        /// otherwise the copy is done from 'matrix' to 'weights'.
2466        ///
2467        /// Parameter `matrixOffset`: A (valid) offset into matrix to be applied to the copy operation.
2468        ///
2469        /// # Safety
2470        ///
2471        /// NOTE(review): no safety contract is spelled out for this `unsafe` binding; presumably
2472        /// `matrix_id` and `matrix_offset` must be valid as described above - confirm.
2473        #[unsafe(method(encodeCopyWeightsToCommandBuffer:weights:matrixId:matrix:copyFromWeightsToMatrix:matrixOffset:))]
2474        #[unsafe(method_family = none)]
2475        pub unsafe fn encodeCopyWeightsToCommandBuffer_weights_matrixId_matrix_copyFromWeightsToMatrix_matrixOffset(
2476            &self,
2477            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
2478            weights: &NSArray<MPSMatrix>,
2479            matrix_id: MPSRNNMatrixId,
2480            matrix: &MPSMatrix,
2481            copy_from_weights_to_matrix: bool,
2482            matrix_offset: MTLOrigin,
2483        );
2484
2485        #[cfg(all(feature = "MPSMatrix", feature = "MPSState"))]
2486        /// Encode an MPSRNNMatrixTrainingLayer forward pass kernel for a sequence of inputs into a command buffer.
2487        ///
2488        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
2489        ///
2490        /// Parameter `sourceMatrices`: An array of valid MPSMatrix objects containing the sequence of source matrices.
2491        ///
2492        /// Parameter `sourceOffsets`: An array of byte-offsets into the sourceMatrices, if nil zeros are assumed and
2493        /// if not nil must contain offset for every matrix in sourceMatrices.
2494        ///
2495        /// Parameter `destinationMatrices`: An array of valid MPSMatrices to be overwritten by result matrix sequence.
2496        /// destinationMatrices may not alias sourceMatrices.
2497        ///
2498        /// Parameter `destinationOffsets`: An array of byte-offsets into the destinationMatrices, if nil zeros are assumed and
2499        /// if not nil must contain offset for every matrix in destinationMatrices.
2500        ///
2501        /// Parameter `trainingStates`: An array containing the training states to be passed to the gradient computation
2502        /// encode function.
2503        ///
2504        /// Parameter `recurrentInputState`: An optional state containing the output matrices and memory cells (for LSTMs)
2505        /// of the layer obtained from the previous input matrices in a sequence of inputs.
2506        /// Has to be the output of a previous call to this function or nil (assumed zero).
2507        ///
2508        /// Parameter `recurrentOutputStates`: An array that will be appended with the recurrent output states. May not be nil.
2509        /// If recurrentOutputIsTemporary is YES then all returned recurrent states
2510        /// will be temporary.
2511        ///
2512        /// See: MPSState:isTemporary.
2513        ///
2514        /// Parameter `weights`: An array of valid MPSMatrix objects containing the weights, should be the array
2515        /// that was produced either by
2516        ///
2517        /// See: initWithDevice or
2518        ///
2519        /// See: createWeightMatrices.
2520        ///
2521        /// # Safety
2522        ///
2523        /// - `source_offsets` must be a valid pointer or null.
2524        /// - `destination_offsets` must be a valid pointer or null.
        /// - Per the parameter descriptions above, a non-null offsets pointer must supply one offset
        ///   for every matrix in the matching array.
        /// - NOTE(review): the `recurrentOutputStates` description says it may not be nil, yet this
        ///   binding accepts `Option`; presumably callers should always pass `Some` - confirm
        ///   against Apple's header.
2525        #[unsafe(method(encodeForwardSequenceToCommandBuffer:sourceMatrices:sourceOffsets:destinationMatrices:destinationOffsets:trainingStates:recurrentInputState:recurrentOutputStates:weights:))]
2526        #[unsafe(method_family = none)]
2527        pub unsafe fn encodeForwardSequenceToCommandBuffer_sourceMatrices_sourceOffsets_destinationMatrices_destinationOffsets_trainingStates_recurrentInputState_recurrentOutputStates_weights(
2528            &self,
2529            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
2530            source_matrices: &NSArray<MPSMatrix>,
2531            source_offsets: *mut NSUInteger,
2532            destination_matrices: &NSArray<MPSMatrix>,
2533            destination_offsets: *mut NSUInteger,
2534            training_states: &NSMutableArray<MPSRNNMatrixTrainingState>,
2535            recurrent_input_state: Option<&MPSRNNRecurrentMatrixState>,
2536            recurrent_output_states: Option<&NSMutableArray<MPSRNNRecurrentMatrixState>>,
2537            weights: &NSArray<MPSMatrix>,
2538        );
2539
2540        #[cfg(all(feature = "MPSMatrix", feature = "MPSState"))]
2541        /// Encode an MPSRNNMatrixTrainingLayer forward pass kernel for a sequence of inputs into a command buffer.
2542        ///
        /// This shorter selector takes no offset arrays and no recurrent input/output states.
        ///
2543        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
2544        ///
2545        /// Parameter `sourceMatrices`: An array of valid MPSMatrix objects containing the sequence of source matrices.
2546        ///
2547        /// Parameter `destinationMatrices`: An array of valid MPSMatrices to be overwritten by result matrix sequence.
2548        /// destinationMatrices may not alias sourceMatrices.
2549        ///
2550        /// Parameter `trainingStates`: An array containing the training states to be passed to the gradient computation
2551        /// encode function.
2552        ///
2553        /// Parameter `weights`: An array of valid MPSMatrix objects containing the weights, should be the array
2554        /// that was produced either by
2555        ///
2556        /// See: initWithDevice or
2557        ///
2558        /// See: createWeightMatrices.
        ///
        /// # Safety
        ///
        /// No explicit safety contract is given in this file for this method.
        /// NOTE(review): presumably the requirements in the parameter descriptions above (valid
        /// objects, no aliasing between destination and source) are what the caller must uphold -
        /// confirm against Apple's header.
2559        #[unsafe(method(encodeForwardSequenceToCommandBuffer:sourceMatrices:destinationMatrices:trainingStates:weights:))]
2560        #[unsafe(method_family = none)]
2561        pub unsafe fn encodeForwardSequenceToCommandBuffer_sourceMatrices_destinationMatrices_trainingStates_weights(
2562            &self,
2563            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
2564            source_matrices: &NSArray<MPSMatrix>,
2565            destination_matrices: &NSArray<MPSMatrix>,
2566            training_states: &NSMutableArray<MPSRNNMatrixTrainingState>,
2567            weights: &NSArray<MPSMatrix>,
2568        );
2569
2570        #[cfg(all(feature = "MPSMatrix", feature = "MPSState"))]
2571        /// Encode an MPSRNNMatrixTrainingLayer gradient pass kernel for a sequence of input gradients into a command buffer.
2572        /// NOTE: The time sequence indexing follows the array indexing in the inputs: sourceGradients[0] has to contain the
2573        /// gradients corresponding to the first matrix in the forward pass corresponding to the current subsequence, which is
2574        /// typically sourceMatrices[0].
2575        ///
2576        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
2577        ///
2578        /// Parameter `forwardSources`: An array of MPSMatrix objects containing the sequence of source matrices of the forward pass.
2579        ///
2580        /// Parameter `forwardSourceOffsets`: An array of byte-offsets into the forwardSources, if nil zeros are assumed and
2581        /// if not nil must contain offset for every matrix in forwardSources.
2582        ///
2583        /// Parameter `sourceGradients`: An array of valid MPSMatrix objects containing the sequence of source gradient matrices.
2584        ///
2585        /// Parameter `sourceGradientOffsets`: An array of byte-offsets into the sourceGradients, if nil zeros are assumed and
2586        /// if not nil must contain offset for every matrix in sourceGradients.
2587        ///
2588        /// Parameter `destinationGradients`: An array of valid MPSMatrix objects that will receive the backpropagated gradients, may be
2589        /// nil if not needed (for example first layer in graph).
2590        ///
2591        /// Parameter `destinationOffsets`: An array of byte-offsets into the destinationGradients, if nil zeros are assumed and
2592        /// if not nil must contain offset for every matrix in destinationGradients.
2593        ///
2594        /// Parameter `weightGradients`: An array of valid MPSMatrix objects that will receive the gradient wrt. weights and
2595        /// biases of the layer - should be the array that was produced either
2596        /// by
2597        ///
2598        /// See: initWithDevice or
2599        ///
2600        /// See: createWeightMatrices. May be nil in which case
2601        /// the gradients for the weights are not computed.
2602        ///
2603        /// Parameter `trainingStates`: An array containing the training states from the forward pass - the array must contain
2604        /// the states corresponding to the input gradients in sourceGradients.
2605        ///
2606        /// Parameter `recurrentInputState`: An optional state containing the output matrices and memory cells (for LSTMs)
2607        /// of the layer obtained from the previous input gradients in a sequence of inputs.
2608        /// Has to be the output of a previous call to this function or nil (assumed zero).
2609        ///
2610        /// Parameter `recurrentOutputStates`: An array that will be appended with the recurrent output states. Can be nil.
2611        /// If recurrentOutputIsTemporary is YES then all returned recurrent states
2612        /// will be temporary.
2613        ///
2614        /// See: MPSState:isTemporary.
2615        ///
2616        /// Parameter `weights`: An array of valid MPSMatrix objects containing the weights, should be the array
2617        /// that was produced either by
2618        ///
2619        /// See: initWithDevice or
2620        ///
2621        /// See: createWeightMatrices.
2622        ///
2623        /// # Safety
2624        ///
2625        /// - `forward_source_offsets` must be a valid pointer or null.
2626        /// - `source_gradient_offsets` must be a valid pointer or null.
2627        /// - `destination_offsets` must be a valid pointer or null.
        /// - Per the parameter descriptions above, a non-null offsets pointer must supply one offset
        ///   for every matrix in the matching array.
2628        #[unsafe(method(encodeGradientSequenceToCommandBuffer:forwardSources:forwardSourceOffsets:sourceGradients:sourceGradientOffsets:destinationGradients:destinationOffsets:weightGradients:trainingStates:recurrentInputState:recurrentOutputStates:weights:))]
2629        #[unsafe(method_family = none)]
2630        pub unsafe fn encodeGradientSequenceToCommandBuffer_forwardSources_forwardSourceOffsets_sourceGradients_sourceGradientOffsets_destinationGradients_destinationOffsets_weightGradients_trainingStates_recurrentInputState_recurrentOutputStates_weights(
2631            &self,
2632            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
2633            forward_sources: &NSArray<MPSMatrix>,
2634            forward_source_offsets: *mut NSUInteger,
2635            source_gradients: &NSArray<MPSMatrix>,
2636            source_gradient_offsets: *mut NSUInteger,
2637            destination_gradients: Option<&NSArray<MPSMatrix>>,
2638            destination_offsets: *mut NSUInteger,
2639            weight_gradients: Option<&NSArray<MPSMatrix>>,
2640            training_states: &NSArray<MPSRNNMatrixTrainingState>,
2641            recurrent_input_state: Option<&MPSRNNRecurrentMatrixState>,
2642            recurrent_output_states: Option<&NSMutableArray<MPSRNNRecurrentMatrixState>>,
2643            weights: &NSArray<MPSMatrix>,
2644        );
2645
2646        #[cfg(all(feature = "MPSMatrix", feature = "MPSState"))]
2647        /// Encode an MPSRNNMatrixTrainingLayer gradient pass kernel for a sequence of input gradients into a command buffer.
2648        /// NOTE: The time sequence indexing follows the array indexing in the inputs: sourceGradients[0] has to contain the
2649        /// gradients corresponding to the first matrix in the forward pass corresponding to the current subsequence, which is
2650        /// typically sourceMatrices[0].
2651        ///
        /// This shorter selector takes no offset arrays and no recurrent input/output states.
        ///
2652        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded filter
2653        ///
2654        /// Parameter `forwardSources`: An array of MPSMatrix objects containing the sequence of source matrices of the forward pass.
2655        ///
2656        /// Parameter `sourceGradients`: An array of MPSMatrix objects containing the sequence of source gradient matrices.
2657        ///
2658        /// Parameter `destinationGradients`: An array of valid MPSMatrix objects that will receive the backpropagated gradients, may be
2659        /// nil if not needed (for example first layer in graph).
2660        ///
2661        /// Parameter `weightGradients`: An array of valid MPSMatrix objects that will receive the gradient wrt. weights and
2662        /// biases of the layer - should be the array that was produced either
2663        /// by
2664        ///
2665        /// See: initWithDevice or
2666        ///
2667        /// See: createWeightMatrices. May be nil in which case
2668        /// the gradients for the weights are not computed.
2669        /// NOTE: The weight gradients are accumulated on top of existing values.
2670        /// (NOTE(review): the original sentence is cut off after "so"; presumably the matrices need to be zeroed before first use - confirm against Apple's header.)
2671        ///
2672        /// Parameter `trainingStates`: An array containing the training states from the forward pass - the array must contain
2673        /// the states corresponding to the input gradients in sourceGradients.
2674        ///
2675        /// Parameter `weights`: An array of valid MPSMatrix objects containing the weights, should be the array
2676        /// that was produced either by
2677        ///
2678        /// See: initWithDevice or
2679        ///
2680        /// See: createWeightMatrices.
        ///
        /// # Safety
        ///
        /// No explicit safety contract is given in this file for this method.
        /// NOTE(review): presumably the requirements in the parameter descriptions above are what
        /// the caller must uphold - confirm against Apple's header.
2681        #[unsafe(method(encodeGradientSequenceToCommandBuffer:forwardSources:sourceGradients:destinationGradients:weightGradients:trainingStates:weights:))]
2682        #[unsafe(method_family = none)]
2683        pub unsafe fn encodeGradientSequenceToCommandBuffer_forwardSources_sourceGradients_destinationGradients_weightGradients_trainingStates_weights(
2684            &self,
2685            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
2686            forward_sources: &NSArray<MPSMatrix>,
2687            source_gradients: &NSArray<MPSMatrix>,
2688            destination_gradients: Option<&NSArray<MPSMatrix>>,
2689            weight_gradients: Option<&NSArray<MPSMatrix>>,
2690            training_states: &NSArray<MPSRNNMatrixTrainingState>,
2691            weights: &NSArray<MPSMatrix>,
2692        );
2693
2694        /// NSSecureCoding compatibility
2695        ///
2696        /// See
2697        /// MPSKernel#initWithCoder.
2698        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSRNNMatrixTrainingLayer
2699        ///
2700        /// Parameter `device`: The MTLDevice on which to make the MPSRNNMatrixTrainingLayer
2701        ///
2702        /// Returns: A new MPSRNNMatrixTrainingLayer object, or nil if failure.
        /// In this binding a nil result is surfaced as `None`.
2703        ///
2704        /// # Safety
2705        ///
2706        /// `a_decoder` possibly has further requirements.
2707        #[unsafe(method(initWithCoder:device:))]
2708        #[unsafe(method_family = init)]
2709        pub unsafe fn initWithCoder_device(
2710            this: Allocated<Self>,
2711            a_decoder: &NSCoder,
2712            device: &ProtocolObject<dyn MTLDevice>,
2713        ) -> Option<Retained<Self>>;
2714
2715        /// Make a copy of this kernel for a new device -
2716        ///
2717        /// See: MPSKernel
2718        ///
2719        /// Parameter `zone`: The NSZone in which to allocate the object
2720        ///
2721        /// Parameter `device`: The device for the new MPSKernel. If nil, then use
2722        /// self.device.
2723        ///
2724        /// Returns: a pointer to a copy of this MPSKernel. This will fail, returning
2725        /// nil if the device is not supported. Devices must be
2726        /// MTLFeatureSet_iOS_GPUFamily2_v1 or later.
2727        ///
2728        /// # Safety
2729        ///
2730        /// `zone` must be a valid pointer or null.
        ///
        /// NOTE(review): the description above says the copy can fail and return nil, yet the
        /// declared return type is the non-optional `Retained<Self>` - confirm how an unsupported
        /// device is reported before relying on this with arbitrary devices.
2731        #[unsafe(method(copyWithZone:device:))]
2732        #[unsafe(method_family = copy)]
2733        pub unsafe fn copyWithZone_device(
2734            &self,
2735            zone: *mut NSZone,
2736            device: Option<&ProtocolObject<dyn MTLDevice>>,
2737        ) -> Retained<Self>;
2738    );
2739}
2740
2741/// Methods declared on superclass `MPSKernel`.
// Declared again on this type so the initializer yields `Self` (presumably the reason the
// generator repeats inherited methods - confirm against the header-translator docs).
2742#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2743impl MPSRNNMatrixTrainingLayer {
2744    extern_methods!(
2745        /// Called by NSCoder to decode MPSKernels
2746        ///
2747        /// This isn't the right interface to decode a MPSKernel, but
2748        /// it is the one that NSCoder uses. To enable your NSCoder
2749        /// (e.g. NSKeyedUnarchiver) to set which device to use
2750        /// extend the object to adopt the MPSDeviceProvider
2751        /// protocol. Otherwise, the Metal system default device
2752        /// will be used.
2753        ///
        /// Parameter `aDecoder`: The NSCoder to decode from.
        ///
        /// Returns: The initialized object, or `None` (presumably when decoding fails - confirm).
        ///
2754        /// # Safety
2755        ///
2756        /// `a_decoder` possibly has further requirements.
2757        #[unsafe(method(initWithCoder:))]
2758        #[unsafe(method_family = init)]
2759        pub unsafe fn initWithCoder(
2760            this: Allocated<Self>,
2761            a_decoder: &NSCoder,
2762        ) -> Option<Retained<Self>>;
2763    );
2764}
2765
2766/// Methods declared on superclass `NSObject`.
2767#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
2768impl MPSRNNMatrixTrainingLayer {
2769    extern_methods!(
        /// Initializes the allocation `this` by sending it the plain Objective-C `init` message.
        ///
        /// # Safety
        ///
        /// This file does not state the preconditions.
        /// NOTE(review): the other initializers here take a device and/or descriptor, so a layer
        /// created this way may not be usable - confirm against Apple's header.
2770        #[unsafe(method(init))]
2771        #[unsafe(method_family = init)]
2772        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
2773
        /// Creates a new object by sending the class the Objective-C `new` message.
        ///
        /// # Safety
        ///
        /// This file does not state the preconditions.
        /// NOTE(review): presumably the same caveats as for `init` apply - confirm.
2774        #[unsafe(method(new))]
2775        #[unsafe(method_family = new)]
2776        pub unsafe fn new() -> Retained<Self>;
2777    );
2778}