objc2_metal_performance_shaders/generated/MPSNeuralNetwork/
MPSNNOptimizers.rs

1//! This file has been automatically generated by `objc2`'s `header-translator`.
2//! DO NOT EDIT
3use core::ffi::*;
4use core::ptr::NonNull;
5use objc2::__framework_prelude::*;
6use objc2_foundation::*;
7use objc2_metal::*;
8
9use crate::*;
10
11/// [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnregularizationtype?language=objc)
12// NS_ENUM
13#[repr(transparent)]
14#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
15pub struct MPSNNRegularizationType(pub NSUInteger);
16impl MPSNNRegularizationType {
17    #[doc(alias = "MPSNNRegularizationTypeNone")]
18    pub const None: Self = Self(0);
19    /// Apply L1 regularization. L1 norm of weights, will be considered to be added to the loss to be minimized.
20    /// the gradient of the regularization loss turns to be 1 scaled with regularizationScale,
21    /// so we add that to the incoming gradient of value.
22    #[doc(alias = "MPSNNRegularizationTypeL1")]
23    pub const L1: Self = Self(1);
24    /// Apply L2 regularization. L2 norm of weights, will be considered to be added to the loss to be minimized.
25    /// the gradient of the regularization loss turns to be the original value scaled with regularizationScale,
26    /// so we add that to the incoming gradient of value.
27    #[doc(alias = "MPSNNRegularizationTypeL2")]
28    pub const L2: Self = Self(2);
29}
30
31unsafe impl Encode for MPSNNRegularizationType {
32    const ENCODING: Encoding = NSUInteger::ENCODING;
33}
34
35unsafe impl RefEncode for MPSNNRegularizationType {
36    const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
37}
38
39extern_class!(
40    /// The MPSNNOptimizerDescriptor base class. Optimizers are generally used to update trainable neural network parameters.
41    /// Users are usually expected to call these MPSKernels from the update methods on their Convolution or BatchNormalization data sources.
42    ///
43    /// Before the gradient is used to update the original value, some preprocessing occurs on each gradient where it is scaled or clipped
44    /// If regularization is chosen the appropriate regularization loss gradient is added to the value gradient.
45    ///
46    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizerdescriptor?language=objc)
47    #[unsafe(super(NSObject))]
48    #[derive(Debug, PartialEq, Eq, Hash)]
49    pub struct MPSNNOptimizerDescriptor;
50);
51
52extern_conformance!(
53    unsafe impl NSObjectProtocol for MPSNNOptimizerDescriptor {}
54);
55
56impl MPSNNOptimizerDescriptor {
57    extern_methods!(
58        /// The learningRate at which we update values
59        ///
60        /// The default value is 0.001f
61        #[unsafe(method(learningRate))]
62        #[unsafe(method_family = none)]
63        pub unsafe fn learningRate(&self) -> c_float;
64
65        /// Setter for [`learningRate`][Self::learningRate].
66        #[unsafe(method(setLearningRate:))]
67        #[unsafe(method_family = none)]
68        pub unsafe fn setLearningRate(&self, learning_rate: c_float);
69
70        /// The gradientRescale at which we apply to incoming gradient values
71        ///
72        /// The default value is 1.0
73        #[unsafe(method(gradientRescale))]
74        #[unsafe(method_family = none)]
75        pub unsafe fn gradientRescale(&self) -> c_float;
76
77        /// Setter for [`gradientRescale`][Self::gradientRescale].
78        #[unsafe(method(setGradientRescale:))]
79        #[unsafe(method_family = none)]
80        pub unsafe fn setGradientRescale(&self, gradient_rescale: c_float);
81
82        /// A bool which decides if gradient will be clipped
83        ///
84        /// The default value is NO
85        #[unsafe(method(applyGradientClipping))]
86        #[unsafe(method_family = none)]
87        pub unsafe fn applyGradientClipping(&self) -> bool;
88
89        /// Setter for [`applyGradientClipping`][Self::applyGradientClipping].
90        #[unsafe(method(setApplyGradientClipping:))]
91        #[unsafe(method_family = none)]
92        pub unsafe fn setApplyGradientClipping(&self, apply_gradient_clipping: bool);
93
94        /// The maximum value at which incoming gradient will be clipped before rescaling, applyGradientClipping must be true
95        #[unsafe(method(gradientClipMax))]
96        #[unsafe(method_family = none)]
97        pub unsafe fn gradientClipMax(&self) -> c_float;
98
99        /// Setter for [`gradientClipMax`][Self::gradientClipMax].
100        #[unsafe(method(setGradientClipMax:))]
101        #[unsafe(method_family = none)]
102        pub unsafe fn setGradientClipMax(&self, gradient_clip_max: c_float);
103
104        /// The minimum value at which incoming gradient will be clipped before rescaling, applyGradientClipping must be true
105        #[unsafe(method(gradientClipMin))]
106        #[unsafe(method_family = none)]
107        pub unsafe fn gradientClipMin(&self) -> c_float;
108
109        /// Setter for [`gradientClipMin`][Self::gradientClipMin].
110        #[unsafe(method(setGradientClipMin:))]
111        #[unsafe(method_family = none)]
112        pub unsafe fn setGradientClipMin(&self, gradient_clip_min: c_float);
113
114        /// The regularizationScale at which we apply L1 or L2 regularization, it gets ignored if regularization is None
115        ///
116        /// The default value is 0.0
117        #[unsafe(method(regularizationScale))]
118        #[unsafe(method_family = none)]
119        pub unsafe fn regularizationScale(&self) -> c_float;
120
121        /// Setter for [`regularizationScale`][Self::regularizationScale].
122        #[unsafe(method(setRegularizationScale:))]
123        #[unsafe(method_family = none)]
124        pub unsafe fn setRegularizationScale(&self, regularization_scale: c_float);
125
126        /// The regularizationType which we apply.
127        ///
128        /// The default value is MPSRegularizationTypeNone
129        #[unsafe(method(regularizationType))]
130        #[unsafe(method_family = none)]
131        pub unsafe fn regularizationType(&self) -> MPSNNRegularizationType;
132
133        /// Setter for [`regularizationType`][Self::regularizationType].
134        #[unsafe(method(setRegularizationType:))]
135        #[unsafe(method_family = none)]
136        pub unsafe fn setRegularizationType(&self, regularization_type: MPSNNRegularizationType);
137
138        #[unsafe(method(initWithLearningRate:gradientRescale:regularizationType:regularizationScale:))]
139        #[unsafe(method_family = init)]
140        pub unsafe fn initWithLearningRate_gradientRescale_regularizationType_regularizationScale(
141            this: Allocated<Self>,
142            learning_rate: c_float,
143            gradient_rescale: c_float,
144            regularization_type: MPSNNRegularizationType,
145            regularization_scale: c_float,
146        ) -> Retained<Self>;
147
148        #[unsafe(method(initWithLearningRate:gradientRescale:applyGradientClipping:gradientClipMax:gradientClipMin:regularizationType:regularizationScale:))]
149        #[unsafe(method_family = init)]
150        pub unsafe fn initWithLearningRate_gradientRescale_applyGradientClipping_gradientClipMax_gradientClipMin_regularizationType_regularizationScale(
151            this: Allocated<Self>,
152            learning_rate: c_float,
153            gradient_rescale: c_float,
154            apply_gradient_clipping: bool,
155            gradient_clip_max: c_float,
156            gradient_clip_min: c_float,
157            regularization_type: MPSNNRegularizationType,
158            regularization_scale: c_float,
159        ) -> Retained<Self>;
160
161        #[unsafe(method(optimizerDescriptorWithLearningRate:gradientRescale:regularizationType:regularizationScale:))]
162        #[unsafe(method_family = none)]
163        pub unsafe fn optimizerDescriptorWithLearningRate_gradientRescale_regularizationType_regularizationScale(
164            learning_rate: c_float,
165            gradient_rescale: c_float,
166            regularization_type: MPSNNRegularizationType,
167            regularization_scale: c_float,
168        ) -> Retained<Self>;
169
170        #[unsafe(method(optimizerDescriptorWithLearningRate:gradientRescale:applyGradientClipping:gradientClipMax:gradientClipMin:regularizationType:regularizationScale:))]
171        #[unsafe(method_family = none)]
172        pub unsafe fn optimizerDescriptorWithLearningRate_gradientRescale_applyGradientClipping_gradientClipMax_gradientClipMin_regularizationType_regularizationScale(
173            learning_rate: c_float,
174            gradient_rescale: c_float,
175            apply_gradient_clipping: bool,
176            gradient_clip_max: c_float,
177            gradient_clip_min: c_float,
178            regularization_type: MPSNNRegularizationType,
179            regularization_scale: c_float,
180        ) -> Retained<Self>;
181    );
182}
183
184/// Methods declared on superclass `NSObject`.
185impl MPSNNOptimizerDescriptor {
186    extern_methods!(
187        #[unsafe(method(init))]
188        #[unsafe(method_family = init)]
189        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
190
191        #[unsafe(method(new))]
192        #[unsafe(method_family = new)]
193        pub unsafe fn new() -> Retained<Self>;
194    );
195}
196
197extern_class!(
198    /// The MPSNNOptimizer base class, use one of the child classes, not to be directly used. Optimizers are generally used to update trainable neural network parameters.
199    /// Users are usually expected to call these MPSKernels from the update methods on their Convolution or BatchNormalization data sources.
200    ///
201    /// Before the gradient is used to update the original value, some preprocessing occurs on each gradient where it is scaled or clipped
202    /// If regularization is chosen the appropriate regularization loss gradient is added to the value gradient.
203    ///
204    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizer?language=objc)
205    #[unsafe(super(MPSKernel, NSObject))]
206    #[derive(Debug, PartialEq, Eq, Hash)]
207    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
208    pub struct MPSNNOptimizer;
209);
210
211#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
212extern_conformance!(
213    unsafe impl NSCoding for MPSNNOptimizer {}
214);
215
216#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
217extern_conformance!(
218    unsafe impl NSCopying for MPSNNOptimizer {}
219);
220
221#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
222unsafe impl CopyingHelper for MPSNNOptimizer {
223    type Result = Self;
224}
225
226#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
227extern_conformance!(
228    unsafe impl NSObjectProtocol for MPSNNOptimizer {}
229);
230
231#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
232extern_conformance!(
233    unsafe impl NSSecureCoding for MPSNNOptimizer {}
234);
235
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizer {
    extern_methods!(
        /// The learning rate at which values are updated.
        ///
        /// The default value is 0.001.
        #[unsafe(method(learningRate))]
        #[unsafe(method_family = none)]
        pub unsafe fn learningRate(&self) -> c_float;

        /// The factor applied to incoming gradient values.
        ///
        /// The default value is 1.0.
        #[unsafe(method(gradientRescale))]
        #[unsafe(method_family = none)]
        pub unsafe fn gradientRescale(&self) -> c_float;

        /// Whether incoming gradients are clipped before rescaling.
        ///
        /// The default value is NO (`false`).
        #[unsafe(method(applyGradientClipping))]
        #[unsafe(method_family = none)]
        pub unsafe fn applyGradientClipping(&self) -> bool;

        /// Setter for [`applyGradientClipping`][Self::applyGradientClipping].
        #[unsafe(method(setApplyGradientClipping:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setApplyGradientClipping(&self, apply_gradient_clipping: bool);

        /// The maximum value to which incoming gradients are clipped before rescaling.
        ///
        /// Requires [`applyGradientClipping`][Self::applyGradientClipping] to be `true`.
        #[unsafe(method(gradientClipMax))]
        #[unsafe(method_family = none)]
        pub unsafe fn gradientClipMax(&self) -> c_float;

        /// The minimum value to which incoming gradients are clipped before rescaling.
        ///
        /// Requires [`applyGradientClipping`][Self::applyGradientClipping] to be `true`.
        #[unsafe(method(gradientClipMin))]
        #[unsafe(method_family = none)]
        pub unsafe fn gradientClipMin(&self) -> c_float;

        /// The scale at which L1 or L2 regularization is applied. It is ignored when
        /// [`regularizationType`][Self::regularizationType] is
        /// [`MPSNNRegularizationType::None`].
        ///
        /// The default value is 0.0.
        #[unsafe(method(regularizationScale))]
        #[unsafe(method_family = none)]
        pub unsafe fn regularizationScale(&self) -> c_float;

        /// The type of regularization to apply.
        ///
        /// The default value is [`MPSNNRegularizationType::None`].
        #[unsafe(method(regularizationType))]
        #[unsafe(method_family = none)]
        pub unsafe fn regularizationType(&self) -> MPSNNRegularizationType;

        /// Initializes the optimizer on the given Metal device.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;

        /// Setter for [`learningRate`][Self::learningRate].
        #[unsafe(method(setLearningRate:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setLearningRate(&self, new_learning_rate: c_float);
    );
}
301
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizer {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels.
        ///
        /// This isn't the right interface to decode an MPSKernel, but it is the one that
        /// NSCoder uses. To let your NSCoder (e.g. NSKeyedUnarchiver) choose which device to
        /// use, extend the object to adopt the MPSDeviceProvider protocol. Otherwise, the
        /// Metal system default device will be used.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;

        /// NSSecureCoding compatibility.
        ///
        /// The standard NSSecureCoding/NSCoding method `-initWithCoder:` should work. However,
        /// the file can't know which device your data is allocated on, so that method has to
        /// guess and may guess incorrectly. To avoid that problem, use this method
        /// (`initWithCoder:device:`) instead.
        ///
        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSKernel.
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel.
        ///
        /// Returns: A new MPSKernel object, or nil on failure.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;
    );
}
351
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizer {
    extern_methods!(
        /// Creates an instance by sending `+new` to the class.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;

        /// Initializes already-allocated storage by sending `-init`.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;
    );
}
365
366extern_class!(
367    /// The MPSNNOptimizerStochasticGradientDescent performs a gradient descent with an optional momentum Update
368    /// RMSProp is also known as root mean square propagation.
369    ///
370    /// useNesterov == NO:
371    /// m[t]     = momentumScale * m[t-1] + learningRate * g
372    /// variable = variable - m[t]
373    ///
374    /// useNesterov == YES:
375    /// m[t]     = momentumScale * m[t-1] + g
376    /// variable = variable - (learningRate * (g + m[t] * momentumScale))
377    ///
378    /// where,
379    /// g    is gradient of error wrt variable
380    /// m[t] is momentum of gradients it is a state we keep updating every update iteration
381    ///
382    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizerstochasticgradientdescent?language=objc)
383    #[unsafe(super(MPSNNOptimizer, MPSKernel, NSObject))]
384    #[derive(Debug, PartialEq, Eq, Hash)]
385    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
386    pub struct MPSNNOptimizerStochasticGradientDescent;
387);
388
389#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
390extern_conformance!(
391    unsafe impl NSCoding for MPSNNOptimizerStochasticGradientDescent {}
392);
393
394#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
395extern_conformance!(
396    unsafe impl NSCopying for MPSNNOptimizerStochasticGradientDescent {}
397);
398
399#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
400unsafe impl CopyingHelper for MPSNNOptimizerStochasticGradientDescent {
401    type Result = Self;
402}
403
404#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
405extern_conformance!(
406    unsafe impl NSObjectProtocol for MPSNNOptimizerStochasticGradientDescent {}
407);
408
409#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
410extern_conformance!(
411    unsafe impl NSSecureCoding for MPSNNOptimizerStochasticGradientDescent {}
412);
413
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerStochasticGradientDescent {
    extern_methods!(
        /// The scale at which momentum is updated for the values array.
        ///
        /// The default value is 0.0.
        #[unsafe(method(momentumScale))]
        #[unsafe(method_family = none)]
        pub unsafe fn momentumScale(&self) -> c_float;

        /// Whether the Nesterov-style momentum update is used. Nesterov momentum is considered
        /// an improvement on the usual momentum update.
        ///
        /// The default value is NO (`false`).
        ///
        /// Note: maps to the old, misspelled
        /// [`useNestrovMomentum`][Self::useNestrovMomentum] property.
        #[unsafe(method(useNesterovMomentum))]
        #[unsafe(method_family = none)]
        pub unsafe fn useNesterovMomentum(&self) -> bool;

        /// Old, misspelled name of [`useNesterovMomentum`][Self::useNesterovMomentum].
        /// Prefer that method in new code.
        #[unsafe(method(useNestrovMomentum))]
        #[unsafe(method_family = none)]
        pub unsafe fn useNestrovMomentum(&self) -> bool;

        /// Initializes the optimizer on the given Metal device.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;

        /// Convenience initialization for the momentum update.
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `learningRate`: The learning rate which will be applied.
        ///
        /// Returns: A valid MPSNNOptimizerStochasticGradientDescent object or nil, if failure.
        #[unsafe(method(initWithDevice:learningRate:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_learningRate(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            learning_rate: c_float,
        ) -> Retained<Self>;

        /// Full initialization for the momentum update.
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `momentumScale`: The scale at which momentum is updated for the values
        /// array.
        ///
        /// Parameter `useNesterovMomentum`: Whether to use the Nesterov-style momentum update.
        ///
        /// Parameter `optimizerDescriptor`: The descriptor whose properties (learning rate,
        /// gradient rescaling/clipping and regularization) will be applied.
        ///
        /// Returns: A valid MPSNNOptimizerStochasticGradientDescent object or nil, if failure.
        #[unsafe(method(initWithDevice:momentumScale:useNesterovMomentum:optimizerDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_momentumScale_useNesterovMomentum_optimizerDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            momentum_scale: c_float,
            use_nesterov_momentum: bool,
            optimizer_descriptor: &MPSNNOptimizerDescriptor,
        ) -> Retained<Self>;

        /// Variant of
        /// [`initWithDevice_momentumScale_useNesterovMomentum_optimizerDescriptor`][Self::initWithDevice_momentumScale_useNesterovMomentum_optimizerDescriptor]
        /// that uses the old, misspelled `useNestrovMomentum` selector. Prefer the correctly
        /// spelled method in new code.
        #[unsafe(method(initWithDevice:momentumScale:useNestrovMomentum:optimizerDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_momentumScale_useNestrovMomentum_optimizerDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            momentum_scale: c_float,
            use_nestrov_momentum: bool,
            optimizer_descriptor: &MPSNNOptimizerDescriptor,
        ) -> Retained<Self>;

        #[cfg(feature = "MPSMatrix")]
        /// Encode an MPSNNOptimizerStochasticGradientDescent object to a command buffer to
        /// perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `inputGradientVector`: A valid MPSVector object which specifies the input
        /// vector of gradients for this update.
        ///
        /// Parameter `inputValuesVector`: A valid MPSVector object which specifies the input
        /// vector of values to be updated.
        ///
        /// Parameter `inputMomentumVector`: A valid MPSVector object which specifies the
        /// gradient momentum vector which will be updated and overwritten. May be `None`, in
        /// which case no momentum is used (see below).
        ///
        /// Parameter `resultValuesVector`: A valid MPSVector object which specifies the
        /// resultValues vector which will be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// useNesterov == NO:
        /// m[t]     = momentumScale * m[t-1] + learningRate * g
        /// variable = variable - m[t]
        ///
        /// useNesterov == YES:
        /// m[t]     = momentumScale * m[t-1] + g
        /// variable = variable - (learningRate * (g + m[t] * momentumScale))
        ///
        /// inputMomentumVector == nil:
        /// variable = variable - (learningRate * g)
        /// ```
        ///
        /// Here `g` is the gradient of the error with respect to the variable, and `m[t]` is
        /// the momentum of the gradients, a state that is updated on every update iteration.
        #[unsafe(method(encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputMomentumVector:resultValuesVector:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_resultValuesVector(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_vector: &MPSVector,
            input_values_vector: &MPSVector,
            input_momentum_vector: Option<&MPSVector>,
            result_values_vector: &MPSVector,
        );

        #[cfg(feature = "MPSMatrix")]
        /// Matrix counterpart of
        /// [`encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_resultValuesVector`][Self::encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_resultValuesVector].
        /// It encodes an out-of-place update using `MPSMatrix` operands instead of `MPSVector`s.
        /// `input_momentum_matrix` may be `None`.
        ///
        /// NOTE(review): the upstream header gives no separate description for this overload.
        /// The update rules are presumably the same as for the vector variant.
        #[unsafe(method(encodeToCommandBuffer:inputGradientMatrix:inputValuesMatrix:inputMomentumMatrix:resultValuesMatrix:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientMatrix_inputValuesMatrix_inputMomentumMatrix_resultValuesMatrix(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_matrix: &MPSMatrix,
            input_values_matrix: &MPSMatrix,
            input_momentum_matrix: Option<&MPSMatrix>,
            result_values_matrix: &MPSMatrix,
        );

        #[cfg(all(
            feature = "MPSCNNConvolution",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerStochasticGradientDescent object to a command buffer to
        /// perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `convolutionGradientState`: A valid MPSCNNConvolutionGradientState object
        /// which specifies the input state with gradients for this update.
        ///
        /// Parameter `convolutionSourceState`: A valid MPSCNNConvolutionWeightsAndBiasesState
        /// object which specifies the input state with values to be updated.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the
        /// gradient momentum vectors which will be updated and overwritten. Index 0
        /// corresponds to weights and index 1 to biases. The array can be of size 1, in which
        /// case biases won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which
        /// specifies the resultValues state which will be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// useNesterov == NO:
        /// m[t]     = momentumScale * m[t-1] + learningRate * g
        /// variable = variable - m[t]
        ///
        /// useNesterov == YES:
        /// m[t]     = momentumScale * m[t-1] + g
        /// variable = variable - (learningRate * (g + m[t] * momentumScale))
        ///
        /// inputMomentumVectors == nil:
        /// variable = variable - (learningRate * g)
        /// ```
        ///
        /// Here `g` is the gradient of the error with respect to the variable, and `m[t]` is
        /// the momentum of the gradients, a state that is updated on every update iteration.
        #[unsafe(method(encodeToCommandBuffer:convolutionGradientState:convolutionSourceState:inputMomentumVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_convolutionGradientState_convolutionSourceState_inputMomentumVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            convolution_gradient_state: &MPSCNNConvolutionGradientState,
            convolution_source_state: &MPSCNNConvolutionWeightsAndBiasesState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNConvolutionWeightsAndBiasesState,
        );

        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerStochasticGradientDescent object to a command buffer to
        /// perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object
        /// which specifies the input state with gradients and original gamma/beta for this
        /// update.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the
        /// gradient momentum vectors which will be updated and overwritten. Index 0
        /// corresponds to gamma and index 1 to beta. The array can be of size 1, in which case
        /// beta won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which
        /// specifies the resultValues state which will be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// useNesterov == NO:
        /// m[t]     = momentumScale * m[t-1] + learningRate * g
        /// variable = variable - m[t]
        ///
        /// useNesterov == YES:
        /// m[t]     = momentumScale * m[t-1] + g
        /// variable = variable - (learningRate * (g + m[t] * momentumScale))
        ///
        /// inputMomentumVectors == nil:
        /// variable = variable - (learningRate * g)
        /// ```
        ///
        /// Here `g` is the gradient of the error with respect to the variable, and `m[t]` is
        /// the momentum of the gradients, a state that is updated on every update iteration.
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationState:inputMomentumVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationState_inputMomentumVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );

        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerStochasticGradientDescent object to a command buffer to
        /// perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationGradientState`: A valid MPSCNNBatchNormalizationState
        /// object which specifies the input state with gradients for this update.
        ///
        /// Parameter `batchNormalizationSourceState`: A valid MPSCNNBatchNormalizationState
        /// object which specifies the input state with original gamma/beta for this update.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the
        /// gradient momentum vectors which will be updated and overwritten. Index 0
        /// corresponds to gamma and index 1 to beta. The array can be of size 1, in which case
        /// beta won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which
        /// specifies the resultValues state which will be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// useNesterov == NO:
        /// m[t]     = momentumScale * m[t-1] + learningRate * g
        /// variable = variable - m[t]
        ///
        /// useNesterov == YES:
        /// m[t]     = momentumScale * m[t-1] + g
        /// variable = variable - (learningRate * (g + m[t] * momentumScale))
        ///
        /// inputMomentumVectors == nil:
        /// variable = variable - (learningRate * g)
        /// ```
        ///
        /// Here `g` is the gradient of the error with respect to the variable, and `m[t]` is
        /// the momentum of the gradients, a state that is updated on every update iteration.
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationGradientState:batchNormalizationSourceState:inputMomentumVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationGradientState_batchNormalizationSourceState_inputMomentumVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_gradient_state: &MPSCNNBatchNormalizationState,
            batch_normalization_source_state: &MPSCNNBatchNormalizationState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );
    );
}
700
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerStochasticGradientDescent {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels.
        ///
        /// This isn't the right interface to decode an MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use,
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// May return `None` (nil).
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;

        /// NSSecureCoding compatibility.
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// `-initWithCoder:` should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly. To avoid
        /// that problem, use this method (`initWithCoder:device:`)
        /// instead of `-initWithCoder:`.
        ///
        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSKernel object, or nil (`None`) on failure.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;
    );
}
750
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerStochasticGradientDescent {
    extern_methods!(
        /// Initialize the receiver via the inherited `-[NSObject init]`.
        ///
        /// NOTE(review): the other MPS kernels in this module are created with an
        /// `MTLDevice` (`initWithDevice…` / `initWithCoder:device:`); confirm that a
        /// plain `init` yields a usable kernel before relying on it.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Allocate and initialize a new instance via the inherited `+[NSObject new]`.
        ///
        /// NOTE(review): same device caveat as [`Self::init`] — confirm before use.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
764
extern_class!(
    /// The MPSNNOptimizerRMSProp performs an RMSProp update.
    /// RMSProp stands for root mean square propagation.
    ///
    /// ```text
    /// s[t]     = decay * s[t-1] + (1 - decay) * (g ^ 2)
    /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
    /// ```
    ///
    /// where,
    /// - `g` is the gradient of the error with respect to `variable`
    /// - `s[t]` is the weighted sum of squares of gradients
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizerrmsprop?language=objc)
    #[unsafe(super(MPSNNOptimizer, MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSNNOptimizerRMSProp;
);
782
// Protocol conformances of `MPSNNOptimizerRMSProp`. Each is gated on the same
// features as the class declaration itself.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSNNOptimizerRMSProp {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSNNOptimizerRMSProp {}
);

// Copying this class produces the same class (`Result = Self`); this is the
// result type used by objc2's `copy()` for `NSCopying` types.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSNNOptimizerRMSProp {
    type Result = Self;
}

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSNNOptimizerRMSProp {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSNNOptimizerRMSProp {}
);
807
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerRMSProp {
    extern_methods!(
        /// The decay rate used when updating the running sum of squares `s[t]`.
        ///
        /// Default value is 0.9
        #[unsafe(method(decay))]
        #[unsafe(method_family = none)]
        pub unsafe fn decay(&self) -> c_double;

        /// The epsilon added to `sqrt(s[t])` in the update denominator.
        ///
        /// This value is used to avoid division by zero; the default value is 1e-8.
        #[unsafe(method(epsilon))]
        #[unsafe(method_family = none)]
        pub unsafe fn epsilon(&self) -> c_float;

        /// Initialize an RMSProp optimizer on `device`.
        ///
        /// NOTE(review): the hyperparameters this initializer uses are not stated
        /// here; presumably the defaults documented on `decay` / `epsilon` — confirm.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;

        /// Convenience initialization for the RMSProp update.
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `learningRate`: The learningRate which will be applied
        ///
        /// Returns: A valid MPSNNOptimizerRMSProp object or nil, if failure.
        /// Note that this Rust binding declares a non-optional `Retained<Self>` return.
        #[unsafe(method(initWithDevice:learningRate:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_learningRate(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            learning_rate: c_float,
        ) -> Retained<Self>;

        /// Full initialization for the RMSProp update.
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `decay`: The decay to update sumOfSquares
        ///
        /// Parameter `epsilon`: The epsilon which will be applied
        ///
        /// Parameter `optimizerDescriptor`: The optimizerDescriptor which will have a bunch of properties to be applied
        ///
        /// Returns: A valid MPSNNOptimizerRMSProp object or nil, if failure.
        /// Note that this Rust binding declares a non-optional `Retained<Self>` return.
        #[unsafe(method(initWithDevice:decay:epsilon:optimizerDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_decay_epsilon_optimizerDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            decay: c_double,
            epsilon: c_float,
            optimizer_descriptor: &MPSNNOptimizerDescriptor,
        ) -> Retained<Self>;

        #[cfg(feature = "MPSMatrix")]
        /// Encode an MPSNNOptimizerRMSProp object to a command buffer to perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `inputGradientVector`: A valid MPSVector object which specifies the input vector of gradients for this update.
        ///
        /// Parameter `inputValuesVector`: A valid MPSVector object which specifies the input vector of values to be updated.
        ///
        /// Parameter `inputSumOfSquaresVector`: A valid MPSVector object which specifies the running sum-of-squares
        /// vector (`s` in the formula below) which will be updated and overwritten.
        ///
        /// Parameter `resultValuesVector`: A valid MPSVector object which specifies the resultValues vector which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// s[t]     = decay * s[t-1] + (1 - decay) * (g ^ 2)
        /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
        /// ```
        ///
        /// where,
        /// - `g` is the gradient of the error with respect to `variable`
        /// - `s[t]` is the weighted sum of squares of gradients
        #[unsafe(method(encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputSumOfSquaresVector:resultValuesVector:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputSumOfSquaresVector_resultValuesVector(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_vector: &MPSVector,
            input_values_vector: &MPSVector,
            input_sum_of_squares_vector: &MPSVector,
            result_values_vector: &MPSVector,
        );

        #[cfg(feature = "MPSMatrix")]
        /// Matrix variant of the vector-based RMSProp encode: each argument plays the
        /// same role as in
        /// `encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputSumOfSquaresVector:resultValuesVector:`,
        /// but takes an `MPSMatrix` instead of an `MPSVector`.
        ///
        /// The sum-of-squares and result matrices are updated and overwritten.
        #[unsafe(method(encodeToCommandBuffer:inputGradientMatrix:inputValuesMatrix:inputSumOfSquaresMatrix:resultValuesMatrix:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientMatrix_inputValuesMatrix_inputSumOfSquaresMatrix_resultValuesMatrix(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_matrix: &MPSMatrix,
            input_values_matrix: &MPSMatrix,
            input_sum_of_squares_matrix: &MPSMatrix,
            result_values_matrix: &MPSMatrix,
        );

        #[cfg(all(
            feature = "MPSCNNConvolution",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerRMSProp object to a command buffer to perform an out-of-place
        /// update of convolution weights and biases.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `convolutionGradientState`: A valid MPSCNNConvolutionGradientState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `convolutionSourceState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the input state with values to be updated.
        ///
        /// Parameter `inputSumOfSquaresVectors`: An optional array of MPSVector objects which specifies the gradient
        /// sumOfSquares vectors which will be updated and overwritten. Index 0 corresponds to weights and index 1 to
        /// biases; the array may have size 1, in which case biases won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// s[t]     = decay * s[t-1] + (1 - decay) * (g ^ 2)
        /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
        /// ```
        ///
        /// where,
        /// - `g` is the gradient of the error with respect to `variable`
        /// - `s[t]` is the weighted sum of squares of gradients
        #[unsafe(method(encodeToCommandBuffer:convolutionGradientState:convolutionSourceState:inputSumOfSquaresVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_convolutionGradientState_convolutionSourceState_inputSumOfSquaresVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            convolution_gradient_state: &MPSCNNConvolutionGradientState,
            convolution_source_state: &MPSCNNConvolutionWeightsAndBiasesState,
            input_sum_of_squares_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNConvolutionWeightsAndBiasesState,
        );

        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerRMSProp object to a command buffer to perform an out-of-place
        /// update of batch-normalization gamma/beta from a single state.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients and original gamma/beta for this update.
        ///
        /// Parameter `inputSumOfSquaresVectors`: An optional array of MPSVector objects which specifies the gradient
        /// sumOfSquares vectors which will be updated and overwritten. Index 0 corresponds to gamma and index 1 to
        /// beta; the array may have size 1, in which case beta won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// s[t]     = decay * s[t-1] + (1 - decay) * (g ^ 2)
        /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
        /// ```
        ///
        /// where,
        /// - `g` is the gradient of the error with respect to `variable`
        /// - `s[t]` is the weighted sum of squares of gradients
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationState:inputSumOfSquaresVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationState_inputSumOfSquaresVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            input_sum_of_squares_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );

        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerRMSProp object to a command buffer to perform an out-of-place
        /// update of batch-normalization gamma/beta, taking the gradients and the original gamma/beta
        /// from two separate states.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationGradientState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `batchNormalizationSourceState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with original gamma/beta for this update.
        ///
        /// Parameter `inputSumOfSquaresVectors`: An optional array of MPSVector objects which specifies the gradient
        /// sumOfSquares vectors which will be updated and overwritten. Index 0 corresponds to gamma and index 1 to
        /// beta; the array may have size 1, in which case beta won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// s[t]     = decay * s[t-1] + (1 - decay) * (g ^ 2)
        /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
        /// ```
        ///
        /// where,
        /// - `g` is the gradient of the error with respect to `variable`
        /// - `s[t]` is the weighted sum of squares of gradients
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationGradientState:batchNormalizationSourceState:inputSumOfSquaresVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationGradientState_batchNormalizationSourceState_inputSumOfSquaresVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_gradient_state: &MPSCNNBatchNormalizationState,
            batch_normalization_source_state: &MPSCNNBatchNormalizationState,
            input_sum_of_squares_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );
    );
}
1046
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerRMSProp {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels.
        ///
        /// This isn't the right interface to decode an MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use,
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// May return `None` (nil).
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;

        /// NSSecureCoding compatibility.
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// `-initWithCoder:` should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly. To avoid
        /// that problem, use this method (`initWithCoder:device:`)
        /// instead of `-initWithCoder:`.
        ///
        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSKernel object, or nil (`None`) on failure.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;
    );
}
1096
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerRMSProp {
    extern_methods!(
        /// Initialize the receiver via the inherited `-[NSObject init]`.
        ///
        /// NOTE(review): this class also offers device-based initializers
        /// (`initWithDevice…`); confirm that a plain `init` yields a usable
        /// kernel before relying on it.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Allocate and initialize a new instance via the inherited `+[NSObject new]`.
        ///
        /// NOTE(review): same device caveat as [`Self::init`] — confirm before use.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
1110
extern_class!(
    /// The MPSNNOptimizerAdam performs an Adam update.
    ///
    /// Initialization time:
    ///
    /// ```text
    /// m[0] = 0 (Initialize initial 1st moment vector aka momentum, user is responsible for this)
    /// v[0] = 0 (Initialize initial 2nd moment vector aka velocity, user is responsible for this)
    /// t    = 0 (Initialize timestep)
    /// ```
    ///
    /// <https://arxiv.org/abs/1412.6980>
    ///
    /// At update time:
    ///
    /// ```text
    /// t = t + 1
    /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
    ///
    /// m[t]     = beta1 * m[t-1] + (1 - beta1) * g
    /// v[t]     = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
    /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
    /// ```
    ///
    /// where,
    /// - `g` is the gradient of the error with respect to `variable`
    /// - `v[t]` is the velocity
    /// - `m[t]` is the momentum
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizeradam?language=objc)
    #[unsafe(super(MPSNNOptimizer, MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSNNOptimizerAdam;
);
1140
// Protocol conformances of `MPSNNOptimizerAdam`. Each is gated on the same
// features as the class declaration itself.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSNNOptimizerAdam {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSNNOptimizerAdam {}
);

// Copying this class produces the same class (`Result = Self`); this is the
// result type used by objc2's `copy()` for `NSCopying` types.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSNNOptimizerAdam {
    type Result = Self;
}

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSNNOptimizerAdam {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSNNOptimizerAdam {}
);
1165
1166#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1167impl MPSNNOptimizerAdam {
1168    extern_methods!(
        /// The `beta1` hyperparameter: the decay rate of the first-moment (momentum)
        /// estimate, as in `m[t] = beta1 * m[t-1] + (1 - beta1) * g`.
        ///
        /// Default value is 0.9
        #[unsafe(method(beta1))]
        #[unsafe(method_family = none)]
        pub unsafe fn beta1(&self) -> c_double;
1175
        /// The `beta2` hyperparameter: the decay rate of the second-moment (velocity)
        /// estimate, as in `v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)`.
        ///
        /// Default value is 0.999
        #[unsafe(method(beta2))]
        #[unsafe(method_family = none)]
        pub unsafe fn beta2(&self) -> c_double;
1182
        /// The epsilon added to `sqrt(v[t])` in the update denominator.
        ///
        /// This value is used to avoid division by zero; the default value is 1e-8.
        #[unsafe(method(epsilon))]
        #[unsafe(method_family = none)]
        pub unsafe fn epsilon(&self) -> c_float;
1189
        /// Current timeStep for the update, i.e. the number of times an update has occurred.
        ///
        /// This is the `t` used in the bias correction
        /// `lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)`.
        #[unsafe(method(timeStep))]
        #[unsafe(method_family = none)]
        pub unsafe fn timeStep(&self) -> NSUInteger;
1194
        /// Setter for [`timeStep`][Self::timeStep].
        ///
        /// See [`timeStep`][Self::timeStep] for how the value enters the update.
        #[unsafe(method(setTimeStep:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setTimeStep(&self, time_step: NSUInteger);
1199
        /// Initialize an Adam optimizer on `device`.
        ///
        /// NOTE(review): the hyperparameters this initializer uses are not stated
        /// here; presumably the defaults documented on `beta1` / `beta2` / `epsilon` — confirm.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;
1206
        /// Convenience initialization for the Adam update.
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `learningRate`: The learningRate at which we will update values
        ///
        /// Returns: A valid MPSNNOptimizerAdam object or nil, if failure.
        /// Note that this Rust binding declares a non-optional `Retained<Self>` return.
        #[unsafe(method(initWithDevice:learningRate:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_learningRate(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            learning_rate: c_float,
        ) -> Retained<Self>;
1223
        /// Full initialization for the Adam update.
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `beta1`: The beta1 to update values (decay rate of the momentum `m`)
        ///
        /// Parameter `beta2`: The beta2 to update values (decay rate of the velocity `v`)
        ///
        /// Parameter `epsilon`: The epsilon at which we update values
        ///
        /// Parameter `timeStep`: The timeStep at which values will start updating
        ///
        /// Parameter `optimizerDescriptor`: The optimizerDescriptor which will have a bunch of properties to be applied
        ///
        /// Returns: A valid MPSNNOptimizerAdam object or nil, if failure.
        /// Note that this Rust binding declares a non-optional `Retained<Self>` return.
        #[unsafe(method(initWithDevice:beta1:beta2:epsilon:timeStep:optimizerDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_beta1_beta2_epsilon_timeStep_optimizerDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            beta1: c_double,
            beta2: c_double,
            epsilon: c_float,
            time_step: NSUInteger,
            optimizer_descriptor: &MPSNNOptimizerDescriptor,
        ) -> Retained<Self>;
1252
        #[cfg(feature = "MPSMatrix")]
        /// Encode an MPSNNOptimizerAdam object to a command buffer to perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `inputGradientVector`: A valid MPSVector object which specifies the input vector of gradients for this update.
        ///
        /// Parameter `inputValuesVector`: A valid MPSVector object which specifies the input vector of values to be updated.
        ///
        /// Parameter `inputMomentumVector`: A valid MPSVector object which specifies the gradient momentum vector which will
        /// be updated and overwritten.
        ///
        /// Parameter `inputVelocityVector`: A valid MPSVector object which specifies the gradient velocity vector which will
        /// be updated and overwritten.
        ///
        /// Parameter `resultValuesVector`: A valid MPSVector object which specifies the resultValues vector which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// t = t + 1
        /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
        ///
        /// m[t]     = beta1 * m[t-1] + (1 - beta1) * g
        /// v[t]     = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
        /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
        /// ```
        #[unsafe(method(encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputMomentumVector:inputVelocityVector:resultValuesVector:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_inputVelocityVector_resultValuesVector(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_vector: &MPSVector,
            input_values_vector: &MPSVector,
            input_momentum_vector: &MPSVector,
            input_velocity_vector: &MPSVector,
            result_values_vector: &MPSVector,
        );
1292
        #[cfg(feature = "MPSMatrix")]
        /// Matrix variant of the vector-based Adam encode: each argument plays the same
        /// role as in
        /// `encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputMomentumVector:inputVelocityVector:resultValuesVector:`,
        /// but takes an `MPSMatrix` instead of an `MPSVector`.
        ///
        /// The momentum, velocity and result matrices are updated and overwritten.
        #[unsafe(method(encodeToCommandBuffer:inputGradientMatrix:inputValuesMatrix:inputMomentumMatrix:inputVelocityMatrix:resultValuesMatrix:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientMatrix_inputValuesMatrix_inputMomentumMatrix_inputVelocityMatrix_resultValuesMatrix(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_matrix: &MPSMatrix,
            input_values_matrix: &MPSMatrix,
            input_momentum_matrix: &MPSMatrix,
            input_velocity_matrix: &MPSMatrix,
            result_values_matrix: &MPSMatrix,
        );
1305
        #[cfg(feature = "MPSMatrix")]
        /// Encode the AMSGrad variant of the MPSNNOptimizerAdam update to a command buffer to perform
        /// an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `inputGradientVector`: A valid MPSVector object which specifies the input vector of gradients for this update.
        ///
        /// Parameter `inputValuesVector`: A valid MPSVector object which specifies the input vector of values to be updated.
        ///
        /// Parameter `inputMomentumVector`: A valid MPSVector object which specifies the gradient momentum vector which will
        /// be updated and overwritten.
        ///
        /// Parameter `inputVelocityVector`: A valid MPSVector object which specifies the gradient velocity vector which will
        /// be updated and overwritten.
        ///
        /// Parameter `maximumVelocityVector`: An optional MPSVector object which specifies the maximum velocity vector which will
        /// be updated and overwritten. May be nil (`None`); if so, the normal Adam optimizer behaviour is followed.
        ///
        /// Parameter `resultValuesVector`: A valid MPSVector object which specifies the resultValues vector which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied at update time:
        ///
        /// ```text
        /// t = t + 1
        /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
        ///
        /// m[t]      = beta1 * m[t-1] + (1 - beta1) * g
        /// v[t]      = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
        /// maxVel[t] = max(maxVel[t-1], v[t])
        /// variable  = variable - lr[t] * m[t] / (sqrt(maxVel[t]) + epsilon)
        /// ```
        #[unsafe(method(encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputMomentumVector:inputVelocityVector:maximumVelocityVector:resultValuesVector:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_inputVelocityVector_maximumVelocityVector_resultValuesVector(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_vector: &MPSVector,
            input_values_vector: &MPSVector,
            input_momentum_vector: &MPSVector,
            input_velocity_vector: &MPSVector,
            maximum_velocity_vector: Option<&MPSVector>,
            result_values_vector: &MPSVector,
        );
1349
        #[cfg(feature = "MPSMatrix")]
        /// Encode an AMSGrad variant of MPSNNOptimizerAdam to a command buffer, operating on
        /// `MPSMatrix` operands.
        ///
        /// NOTE(review): no documentation was generated for this overload. It takes the same
        /// arguments as the `MPSVector` overload
        /// `encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_inputVelocityVector_maximumVelocityVector_resultValuesVector`,
        /// with every vector replaced by a matrix. Presumably it applies the same AMSGrad update
        /// element-wise — confirm against Apple's documentation.
        ///
        /// `maximum_velocity_matrix` is an `Option`; `None` is sent as `nil`. For the vector
        /// overload, `nil` means plain Adam behaviour is followed (presumably the same here).
        ///
        /// # Safety
        ///
        /// This is an Objective-C message send whose preconditions are not checked on the Rust
        /// side. NOTE(review): the generator recorded no specific requirements. Presumably all
        /// matrices must be valid, compatibly sized, and allocated on this kernel's device —
        /// confirm before relying on it.
        #[unsafe(method(encodeToCommandBuffer:inputGradientMatrix:inputValuesMatrix:inputMomentumMatrix:inputVelocityMatrix:maximumVelocityMatrix:resultValuesMatrix:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientMatrix_inputValuesMatrix_inputMomentumMatrix_inputVelocityMatrix_maximumVelocityMatrix_resultValuesMatrix(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_matrix: &MPSMatrix,
            input_values_matrix: &MPSMatrix,
            input_momentum_matrix: &MPSMatrix,
            input_velocity_matrix: &MPSMatrix,
            maximum_velocity_matrix: Option<&MPSMatrix>,
            result_values_matrix: &MPSMatrix,
        );
1363
        #[cfg(all(
            feature = "MPSCNNConvolution",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerAdam object to a command buffer to perform an out-of-place
        /// update of convolution weights and biases.
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `convolutionGradientState`: A valid MPSCNNConvolutionGradientState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `convolutionSourceState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the input state with values to be updated.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. Index 0 corresponds to weights and index 1 to biases; the array can be of
        /// size 1, in which case biases won't be updated.
        ///
        /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the gradient velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to weights and index 1 to biases; the array can be of
        /// size 1, in which case biases won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// In this overload both vector arrays are `Option`s (`None` is sent as `nil`). The AMSGrad
        /// overload that also takes `maximumVelocityVectors` requires them instead.
        /// NOTE(review): the effect of passing `None` is not documented — confirm.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// t = t + 1
        /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
        ///
        /// m[t]     = beta1 * m[t-1] + (1 - beta1) * g
        /// v[t]     = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
        /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
        /// ```
        ///
        /// # Safety
        ///
        /// This is an Objective-C message send whose preconditions are not checked on the Rust
        /// side. NOTE(review): the generator recorded no specific requirements. Presumably the
        /// states and vectors must be valid for this kernel's device and follow the index
        /// conventions above — confirm against Apple's documentation.
        #[unsafe(method(encodeToCommandBuffer:convolutionGradientState:convolutionSourceState:inputMomentumVectors:inputVelocityVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_convolutionGradientState_convolutionSourceState_inputMomentumVectors_inputVelocityVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            convolution_gradient_state: &MPSCNNConvolutionGradientState,
            convolution_source_state: &MPSCNNConvolutionWeightsAndBiasesState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            input_velocity_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNConvolutionWeightsAndBiasesState,
        );
1410
        #[cfg(all(
            feature = "MPSCNNConvolution",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an AMSGrad variant of MPSNNOptimizerAdam object to a command buffer to perform an
        /// out-of-place update of convolution weights and biases.
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `convolutionGradientState`: A valid MPSCNNConvolutionGradientState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `convolutionSourceState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the input state with values to be updated.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. Index 0 corresponds to weights and index 1 to biases; the array can be of
        /// size 1, in which case biases won't be updated. Required (not an `Option`) in this overload.
        ///
        /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the gradient velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to weights and index 1 to biases; the array can be of
        /// size 1, in which case biases won't be updated. Required (not an `Option`) in this overload.
        ///
        /// Parameter `maximumVelocityVectors`: An array of MPSVector objects which specifies the maximum velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to weights and index 1 to biases; the array can be of
        /// size 1, in which case biases won't be updated. May be nil (`None`); if nil, normal Adam optimizer behaviour is followed.
        ///
        /// Parameter `resultState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        ///
        /// The following operations would be applied at update time:
        ///
        /// ```text
        /// t = t + 1
        /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
        ///
        /// m[t]     = beta1 * m[t-1] + (1 - beta1) * g
        /// v[t]     = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
        /// maxVel[t] = max(maxVel[t-1],v[t])
        /// variable = variable - lr[t] * m[t] / (sqrt(maxVel[t]) + epsilon)
        /// ```
        ///
        /// # Safety
        ///
        /// This is an Objective-C message send whose preconditions are not checked on the Rust
        /// side. NOTE(review): the generator recorded no specific requirements. Presumably the
        /// states and vectors must be valid for this kernel's device and follow the index
        /// conventions above — confirm against Apple's documentation.
        #[unsafe(method(encodeToCommandBuffer:convolutionGradientState:convolutionSourceState:inputMomentumVectors:inputVelocityVectors:maximumVelocityVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_convolutionGradientState_convolutionSourceState_inputMomentumVectors_inputVelocityVectors_maximumVelocityVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            convolution_gradient_state: &MPSCNNConvolutionGradientState,
            convolution_source_state: &MPSCNNConvolutionWeightsAndBiasesState,
            input_momentum_vectors: &NSArray<MPSVector>,
            input_velocity_vectors: &NSArray<MPSVector>,
            maximum_velocity_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNConvolutionWeightsAndBiasesState,
        );
1463
        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerAdam object to a command buffer to perform an out-of-place
        /// update of batch-normalization gamma and beta.
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients and original gamma/beta for this update.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated.
        ///
        /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the gradient velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// In this overload both vector arrays are `Option`s (`None` is sent as `nil`). The AMSGrad
        /// overload that also takes `maximumVelocityVectors` requires them instead.
        /// NOTE(review): the effect of passing `None` is not documented — confirm.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// t = t + 1
        /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
        ///
        /// m[t]     = beta1 * m[t-1] + (1 - beta1) * g
        /// v[t]     = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
        /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
        /// ```
        ///
        /// # Safety
        ///
        /// This is an Objective-C message send whose preconditions are not checked on the Rust
        /// side. NOTE(review): the generator recorded no specific requirements. Presumably the
        /// states and vectors must be valid for this kernel's device and follow the index
        /// conventions above — confirm against Apple's documentation.
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationState:inputMomentumVectors:inputVelocityVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationState_inputMomentumVectors_inputVelocityVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            input_velocity_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );
1508
        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an AMSGrad variant of MPSNNOptimizerAdam object to a command buffer to perform an
        /// out-of-place update of batch-normalization gamma and beta.
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients and original gamma/beta for this update.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated. Required (not an `Option`) in this overload.
        ///
        /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the gradient velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated. Required (not an `Option`) in this overload.
        ///
        /// Parameter `maximumVelocityVectors`: An array of MPSVector objects which specifies the maximum velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta, matching the other vector
        /// arrays; the array can be of size 1, in which case beta won't be updated. May be nil (`None`); if nil,
        /// normal Adam optimizer behaviour is followed.
        /// NOTE(review): the upstream header text says "weights/biases" for this parameter, apparently
        /// copied from the convolution overload — confirm the gamma/beta reading against Apple's documentation.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        ///
        /// The following operations would be applied at update time:
        ///
        /// ```text
        /// t = t + 1
        /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
        ///
        /// m[t]     = beta1 * m[t-1] + (1 - beta1) * g
        /// v[t]     = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
        /// maxVel[t] = max(maxVel[t-1],v[t])
        /// variable = variable - lr[t] * m[t] / (sqrt(maxVel[t]) + epsilon)
        /// ```
        ///
        /// # Safety
        ///
        /// This is an Objective-C message send whose preconditions are not checked on the Rust
        /// side. NOTE(review): the generator recorded no specific requirements. Presumably the
        /// states and vectors must be valid for this kernel's device and follow the index
        /// conventions above — confirm against Apple's documentation.
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationState:inputMomentumVectors:inputVelocityVectors:maximumVelocityVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationState_inputMomentumVectors_inputVelocityVectors_maximumVelocityVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            input_momentum_vectors: &NSArray<MPSVector>,
            input_velocity_vectors: &NSArray<MPSVector>,
            maximum_velocity_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );
1559
        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerAdam object to a command buffer to perform an out-of-place
        /// update of batch-normalization gamma and beta. The gradients and the original gamma/beta
        /// are taken from two separate states.
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationGradientState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `batchNormalizationSourceState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with original gamma/beta for this update.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated.
        ///
        /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the gradient velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// In this overload both vector arrays are `Option`s (`None` is sent as `nil`). The AMSGrad
        /// overload that also takes `maximumVelocityVectors` requires them instead.
        /// NOTE(review): the effect of passing `None` is not documented — confirm.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// t = t + 1
        /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
        ///
        /// m[t]     = beta1 * m[t-1] + (1 - beta1) * g
        /// v[t]     = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
        /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
        /// ```
        ///
        /// # Safety
        ///
        /// This is an Objective-C message send whose preconditions are not checked on the Rust
        /// side. NOTE(review): the generator recorded no specific requirements. Presumably the
        /// states and vectors must be valid for this kernel's device and follow the index
        /// conventions above — confirm against Apple's documentation.
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationGradientState:batchNormalizationSourceState:inputMomentumVectors:inputVelocityVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationGradientState_batchNormalizationSourceState_inputMomentumVectors_inputVelocityVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_gradient_state: &MPSCNNBatchNormalizationState,
            batch_normalization_source_state: &MPSCNNBatchNormalizationState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            input_velocity_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );
1607
        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an AMSGrad variant of MPSNNOptimizerAdam object to a command buffer to perform an
        /// out-of-place update of batch-normalization gamma and beta. The gradients and the original
        /// gamma/beta are taken from two separate states.
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationGradientState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `batchNormalizationSourceState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with original gamma/beta for this update.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated. Required (not an `Option`) in this overload.
        ///
        /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the gradient velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated. Required (not an `Option`) in this overload.
        ///
        /// Parameter `maximumVelocityVectors`: An array of MPSVector objects which specifies the maximum velocity vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta, matching the other vector
        /// arrays; the array can be of size 1, in which case beta won't be updated. May be nil (`None`); if nil,
        /// normal Adam optimizer behaviour is followed.
        /// NOTE(review): the upstream header text says "weights/biases" for this parameter, apparently
        /// copied from the convolution overload — confirm the gamma/beta reading against Apple's documentation.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        ///
        /// The following operations would be applied at update time:
        ///
        /// ```text
        /// t = t + 1
        /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
        ///
        /// m[t]     = beta1 * m[t-1] + (1 - beta1) * g
        /// v[t]     = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
        /// maxVel[t] = max(maxVel[t-1],v[t])
        /// variable = variable - lr[t] * m[t] / (sqrt(maxVel[t]) + epsilon)
        /// ```
        ///
        /// # Safety
        ///
        /// This is an Objective-C message send whose preconditions are not checked on the Rust
        /// side. NOTE(review): the generator recorded no specific requirements. Presumably the
        /// states and vectors must be valid for this kernel's device and follow the index
        /// conventions above — confirm against Apple's documentation.
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationGradientState:batchNormalizationSourceState:inputMomentumVectors:inputVelocityVectors:maximumVelocityVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationGradientState_batchNormalizationSourceState_inputMomentumVectors_inputVelocityVectors_maximumVelocityVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_gradient_state: &MPSCNNBatchNormalizationState,
            batch_normalization_source_state: &MPSCNNBatchNormalizationState,
            input_momentum_vectors: &NSArray<MPSVector>,
            input_velocity_vectors: &NSArray<MPSVector>,
            maximum_velocity_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );
1661    );
1662}
1663
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerAdam {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels
        ///
        /// This isn't the right interface to decode an MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use,
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// Returns `None` if the Objective-C initializer returns nil.
        /// NOTE(review): presumably this happens when decoding fails — confirm.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;

        /// NSSecureCoding compatibility
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// -initWithCoder: should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly.  To avoid
        /// that problem, use initWithCoder:device: instead.
        ///
        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSKernel object, or nil (`None`) if failure.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;
    );
}
1713
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerAdam {
    extern_methods!(
        /// Initialize an allocated instance by sending it the Objective-C `-init` message.
        ///
        /// NOTE(review): this inherited initializer takes no `MTLDevice`, unlike the
        /// `initWithCoder:device:` path. Confirm whether an optimizer created this way is usable
        /// before relying on it.
        ///
        /// # Safety
        ///
        /// NOTE(review): the generator recorded no specific requirements for this `unsafe fn`.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Create an instance by sending the Objective-C `+new` message to the class
        /// (conventionally `alloc` followed by `init`).
        ///
        /// NOTE(review): the same device caveat as `init` presumably applies — confirm.
        ///
        /// # Safety
        ///
        /// NOTE(review): the generator recorded no specific requirements for this `unsafe fn`.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
1726}