objc2_metal_performance_shaders/generated/MPSNeuralNetwork/MPSNNOptimizers.rs
1//! This file has been automatically generated by `objc2`'s `header-translator`.
2//! DO NOT EDIT
3use core::ffi::*;
4use core::ptr::NonNull;
5use objc2::__framework_prelude::*;
6use objc2_foundation::*;
7use objc2_metal::*;
8
9use crate::*;
10
11/// [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnregularizationtype?language=objc)
12// NS_ENUM
13#[repr(transparent)]
14#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
15pub struct MPSNNRegularizationType(pub NSUInteger);
16impl MPSNNRegularizationType {
17 #[doc(alias = "MPSNNRegularizationTypeNone")]
18 pub const None: Self = Self(0);
19 /// Apply L1 regularization. L1 norm of weights, will be considered to be added to the loss to be minimized.
20 /// the gradient of the regularization loss turns to be 1 scaled with regularizationScale,
21 /// so we add that to the incoming gradient of value.
22 #[doc(alias = "MPSNNRegularizationTypeL1")]
23 pub const L1: Self = Self(1);
24 /// Apply L2 regularization. L2 norm of weights, will be considered to be added to the loss to be minimized.
25 /// the gradient of the regularization loss turns to be the original value scaled with regularizationScale,
26 /// so we add that to the incoming gradient of value.
27 #[doc(alias = "MPSNNRegularizationTypeL2")]
28 pub const L2: Self = Self(2);
29}
30
31unsafe impl Encode for MPSNNRegularizationType {
32 const ENCODING: Encoding = NSUInteger::ENCODING;
33}
34
35unsafe impl RefEncode for MPSNNRegularizationType {
36 const ENCODING_REF: Encoding = Encoding::Pointer(&Self::ENCODING);
37}
38
extern_class!(
    /// The MPSNNOptimizerDescriptor base class. Optimizers are generally used to update trainable neural network parameters.
    /// Users are usually expected to call these MPSKernels from the update methods on their Convolution or BatchNormalization data sources.
    ///
    /// Before the gradient is used to update the original value, some preprocessing occurs on each gradient where it is scaled or clipped.
    /// If regularization is chosen the appropriate regularization loss gradient is added to the value gradient.
    ///
    /// The descriptor bundles the shared optimizer settings: learning rate, gradient rescale,
    /// gradient clipping and regularization. It is passed to optimizer initializers such as
    /// `MPSNNOptimizerStochasticGradientDescent::initWithDevice_momentumScale_useNesterovMomentum_optimizerDescriptor`.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizerdescriptor?language=objc)
    #[unsafe(super(NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    pub struct MPSNNOptimizerDescriptor;
);

// The descriptor is a direct `NSObject` subclass, so it conforms to `NSObjectProtocol`.
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSNNOptimizerDescriptor {}
);
55
56impl MPSNNOptimizerDescriptor {
57 extern_methods!(
58 /// The learningRate at which we update values
59 ///
60 /// The default value is 0.001f
61 #[unsafe(method(learningRate))]
62 #[unsafe(method_family = none)]
63 pub unsafe fn learningRate(&self) -> c_float;
64
65 /// Setter for [`learningRate`][Self::learningRate].
66 #[unsafe(method(setLearningRate:))]
67 #[unsafe(method_family = none)]
68 pub unsafe fn setLearningRate(&self, learning_rate: c_float);
69
70 /// The gradientRescale at which we apply to incoming gradient values
71 ///
72 /// The default value is 1.0
73 #[unsafe(method(gradientRescale))]
74 #[unsafe(method_family = none)]
75 pub unsafe fn gradientRescale(&self) -> c_float;
76
77 /// Setter for [`gradientRescale`][Self::gradientRescale].
78 #[unsafe(method(setGradientRescale:))]
79 #[unsafe(method_family = none)]
80 pub unsafe fn setGradientRescale(&self, gradient_rescale: c_float);
81
82 /// A bool which decides if gradient will be clipped
83 ///
84 /// The default value is NO
85 #[unsafe(method(applyGradientClipping))]
86 #[unsafe(method_family = none)]
87 pub unsafe fn applyGradientClipping(&self) -> bool;
88
89 /// Setter for [`applyGradientClipping`][Self::applyGradientClipping].
90 #[unsafe(method(setApplyGradientClipping:))]
91 #[unsafe(method_family = none)]
92 pub unsafe fn setApplyGradientClipping(&self, apply_gradient_clipping: bool);
93
94 /// The maximum value at which incoming gradient will be clipped before rescaling, applyGradientClipping must be true
95 #[unsafe(method(gradientClipMax))]
96 #[unsafe(method_family = none)]
97 pub unsafe fn gradientClipMax(&self) -> c_float;
98
99 /// Setter for [`gradientClipMax`][Self::gradientClipMax].
100 #[unsafe(method(setGradientClipMax:))]
101 #[unsafe(method_family = none)]
102 pub unsafe fn setGradientClipMax(&self, gradient_clip_max: c_float);
103
104 /// The minimum value at which incoming gradient will be clipped before rescaling, applyGradientClipping must be true
105 #[unsafe(method(gradientClipMin))]
106 #[unsafe(method_family = none)]
107 pub unsafe fn gradientClipMin(&self) -> c_float;
108
109 /// Setter for [`gradientClipMin`][Self::gradientClipMin].
110 #[unsafe(method(setGradientClipMin:))]
111 #[unsafe(method_family = none)]
112 pub unsafe fn setGradientClipMin(&self, gradient_clip_min: c_float);
113
114 /// The regularizationScale at which we apply L1 or L2 regularization, it gets ignored if regularization is None
115 ///
116 /// The default value is 0.0
117 #[unsafe(method(regularizationScale))]
118 #[unsafe(method_family = none)]
119 pub unsafe fn regularizationScale(&self) -> c_float;
120
121 /// Setter for [`regularizationScale`][Self::regularizationScale].
122 #[unsafe(method(setRegularizationScale:))]
123 #[unsafe(method_family = none)]
124 pub unsafe fn setRegularizationScale(&self, regularization_scale: c_float);
125
126 /// The regularizationType which we apply.
127 ///
128 /// The default value is MPSRegularizationTypeNone
129 #[unsafe(method(regularizationType))]
130 #[unsafe(method_family = none)]
131 pub unsafe fn regularizationType(&self) -> MPSNNRegularizationType;
132
133 /// Setter for [`regularizationType`][Self::regularizationType].
134 #[unsafe(method(setRegularizationType:))]
135 #[unsafe(method_family = none)]
136 pub unsafe fn setRegularizationType(&self, regularization_type: MPSNNRegularizationType);
137
138 #[unsafe(method(initWithLearningRate:gradientRescale:regularizationType:regularizationScale:))]
139 #[unsafe(method_family = init)]
140 pub unsafe fn initWithLearningRate_gradientRescale_regularizationType_regularizationScale(
141 this: Allocated<Self>,
142 learning_rate: c_float,
143 gradient_rescale: c_float,
144 regularization_type: MPSNNRegularizationType,
145 regularization_scale: c_float,
146 ) -> Retained<Self>;
147
148 #[unsafe(method(initWithLearningRate:gradientRescale:applyGradientClipping:gradientClipMax:gradientClipMin:regularizationType:regularizationScale:))]
149 #[unsafe(method_family = init)]
150 pub unsafe fn initWithLearningRate_gradientRescale_applyGradientClipping_gradientClipMax_gradientClipMin_regularizationType_regularizationScale(
151 this: Allocated<Self>,
152 learning_rate: c_float,
153 gradient_rescale: c_float,
154 apply_gradient_clipping: bool,
155 gradient_clip_max: c_float,
156 gradient_clip_min: c_float,
157 regularization_type: MPSNNRegularizationType,
158 regularization_scale: c_float,
159 ) -> Retained<Self>;
160
161 #[unsafe(method(optimizerDescriptorWithLearningRate:gradientRescale:regularizationType:regularizationScale:))]
162 #[unsafe(method_family = none)]
163 pub unsafe fn optimizerDescriptorWithLearningRate_gradientRescale_regularizationType_regularizationScale(
164 learning_rate: c_float,
165 gradient_rescale: c_float,
166 regularization_type: MPSNNRegularizationType,
167 regularization_scale: c_float,
168 ) -> Retained<Self>;
169
170 #[unsafe(method(optimizerDescriptorWithLearningRate:gradientRescale:applyGradientClipping:gradientClipMax:gradientClipMin:regularizationType:regularizationScale:))]
171 #[unsafe(method_family = none)]
172 pub unsafe fn optimizerDescriptorWithLearningRate_gradientRescale_applyGradientClipping_gradientClipMax_gradientClipMin_regularizationType_regularizationScale(
173 learning_rate: c_float,
174 gradient_rescale: c_float,
175 apply_gradient_clipping: bool,
176 gradient_clip_max: c_float,
177 gradient_clip_min: c_float,
178 regularization_type: MPSNNRegularizationType,
179 regularization_scale: c_float,
180 ) -> Retained<Self>;
181 );
182}
183
/// Methods declared on superclass `NSObject`.
impl MPSNNOptimizerDescriptor {
    extern_methods!(
        /// Initializes an allocated descriptor through `-[NSObject init]`.
        ///
        /// NOTE(review): property values after a plain `init` presumably match the defaults
        /// documented on the individual properties. Confirm.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Allocates and initializes a descriptor through `+[NSObject new]`.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
196
extern_class!(
    /// The MPSNNOptimizer base class, use one of the child classes, not to be directly used. Optimizers are generally used to update trainable neural network parameters.
    /// Users are usually expected to call these MPSKernels from the update methods on their Convolution or BatchNormalization data sources.
    ///
    /// Before the gradient is used to update the original value, some preprocessing occurs on each gradient where it is scaled or clipped.
    /// If regularization is chosen the appropriate regularization loss gradient is added to the value gradient.
    ///
    /// Only available when both the `MPSCore` and `MPSKernel` crate features are enabled.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizer?language=objc)
    #[unsafe(super(MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSNNOptimizer;
);

// Objective-C protocol conformances declared for MPSNNOptimizer.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSNNOptimizer {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSNNOptimizer {}
);

// `NSCopying` copies of this class are typed as the class itself.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSNNOptimizer {
    type Result = Self;
}

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSNNOptimizer {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSNNOptimizer {}
);
235
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizer {
    extern_methods!(
        /// The learningRate at which we update values
        ///
        /// The default value is 1e-3
        #[unsafe(method(learningRate))]
        #[unsafe(method_family = none)]
        pub unsafe fn learningRate(&self) -> c_float;

        /// The gradientRescale which we apply to incoming gradient values
        ///
        /// The default value is 1.0
        #[unsafe(method(gradientRescale))]
        #[unsafe(method_family = none)]
        pub unsafe fn gradientRescale(&self) -> c_float;

        /// A bool which decides if gradient will be clipped
        ///
        /// The default value is NO
        #[unsafe(method(applyGradientClipping))]
        #[unsafe(method_family = none)]
        pub unsafe fn applyGradientClipping(&self) -> bool;

        /// Setter for [`applyGradientClipping`][Self::applyGradientClipping].
        #[unsafe(method(setApplyGradientClipping:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setApplyGradientClipping(&self, apply_gradient_clipping: bool);

        /// The maximum value at which incoming gradients are clipped before rescaling.
        ///
        /// Only used when [`applyGradientClipping`][Self::applyGradientClipping] is `true`.
        #[unsafe(method(gradientClipMax))]
        #[unsafe(method_family = none)]
        pub unsafe fn gradientClipMax(&self) -> c_float;

        /// The minimum value at which incoming gradients are clipped before rescaling.
        ///
        /// Only used when [`applyGradientClipping`][Self::applyGradientClipping] is `true`.
        #[unsafe(method(gradientClipMin))]
        #[unsafe(method_family = none)]
        pub unsafe fn gradientClipMin(&self) -> c_float;

        /// The regularizationScale at which we apply L1 or L2 regularization.
        ///
        /// It is ignored if [`regularizationType`][Self::regularizationType] is
        /// [`MPSNNRegularizationType::None`].
        ///
        /// The default value is 0.0
        #[unsafe(method(regularizationScale))]
        #[unsafe(method_family = none)]
        pub unsafe fn regularizationScale(&self) -> c_float;

        /// The regularizationType which we apply.
        ///
        /// The default value is [`MPSNNRegularizationType::None`]
        /// (`MPSNNRegularizationTypeNone`).
        #[unsafe(method(regularizationType))]
        #[unsafe(method_family = none)]
        pub unsafe fn regularizationType(&self) -> MPSNNRegularizationType;

        /// Initializes the optimizer for `device`.
        ///
        /// `MPSNNOptimizer` is a base class and is not meant to be used directly. Create one of
        /// its subclasses (for example `MPSNNOptimizerStochasticGradientDescent`) instead.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;

        /// Setter for [`learningRate`][Self::learningRate].
        #[unsafe(method(setLearningRate:))]
        #[unsafe(method_family = none)]
        pub unsafe fn setLearningRate(&self, new_learning_rate: c_float);
    );
}
301
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizer {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels
        ///
        /// This isn't the right interface to decode a MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// Returns `None` if decoding fails.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;

        /// NSSecureCoding compatibility
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// -initWithCoder: should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly. To avoid
        /// that problem, use `initWithCoder:device:`
        /// ([`initWithCoder_device`][Self::initWithCoder_device]) instead.
        ///
        /// Parameter `a_decoder`: The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSKernel object, or `None` on failure.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;
    );
}
351
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizer {
    extern_methods!(
        /// Plain `-[NSObject init]`, which takes no Metal device.
        ///
        /// NOTE(review): MPS kernels are normally created with a device (see
        /// `initWithDevice`). Confirm whether plain `init` is supported for this class.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Plain `+[NSObject new]`, which takes no Metal device. The same caveat as `init` applies.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
365
366extern_class!(
367 /// The MPSNNOptimizerStochasticGradientDescent performs a gradient descent with an optional momentum Update
368 /// RMSProp is also known as root mean square propagation.
369 ///
370 /// useNesterov == NO:
371 /// m[t] = momentumScale * m[t-1] + learningRate * g
372 /// variable = variable - m[t]
373 ///
374 /// useNesterov == YES:
375 /// m[t] = momentumScale * m[t-1] + g
376 /// variable = variable - (learningRate * (g + m[t] * momentumScale))
377 ///
378 /// where,
379 /// g is gradient of error wrt variable
380 /// m[t] is momentum of gradients it is a state we keep updating every update iteration
381 ///
382 /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizerstochasticgradientdescent?language=objc)
383 #[unsafe(super(MPSNNOptimizer, MPSKernel, NSObject))]
384 #[derive(Debug, PartialEq, Eq, Hash)]
385 #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
386 pub struct MPSNNOptimizerStochasticGradientDescent;
387);
388
389#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
390extern_conformance!(
391 unsafe impl NSCoding for MPSNNOptimizerStochasticGradientDescent {}
392);
393
394#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
395extern_conformance!(
396 unsafe impl NSCopying for MPSNNOptimizerStochasticGradientDescent {}
397);
398
399#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
400unsafe impl CopyingHelper for MPSNNOptimizerStochasticGradientDescent {
401 type Result = Self;
402}
403
404#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
405extern_conformance!(
406 unsafe impl NSObjectProtocol for MPSNNOptimizerStochasticGradientDescent {}
407);
408
409#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
410extern_conformance!(
411 unsafe impl NSSecureCoding for MPSNNOptimizerStochasticGradientDescent {}
412);
413
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerStochasticGradientDescent {
    extern_methods!(
        /// The momentumScale at which we update momentum for values array
        ///
        /// Default value is 0.0
        #[unsafe(method(momentumScale))]
        #[unsafe(method_family = none)]
        pub unsafe fn momentumScale(&self) -> c_float;

        /// Nesterov momentum is considered an improvement on the usual momentum update
        ///
        /// Default value is NO
        ///
        /// Note: Maps to old useNestrovMomentum property
        #[unsafe(method(useNesterovMomentum))]
        #[unsafe(method_family = none)]
        pub unsafe fn useNesterovMomentum(&self) -> bool;

        /// Legacy, misspelled name of [`useNesterovMomentum`][Self::useNesterovMomentum].
        ///
        /// Prefer `useNesterovMomentum`, which maps to this old property.
        /// NOTE(review): newer SDKs appear to deprecate this selector. Confirm before relying on it.
        #[unsafe(method(useNestrovMomentum))]
        #[unsafe(method_family = none)]
        pub unsafe fn useNestrovMomentum(&self) -> bool;

        /// Initializes the optimizer for `device`.
        ///
        /// NOTE(review): the remaining settings presumably take the defaults documented on
        /// their properties. Confirm.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;

        /// Convenience initialization for the momentum update
        ///
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `learningRate`: The learningRate which will be applied
        ///
        ///
        /// Returns: A valid MPSNNOptimizerStochasticGradientDescent object or nil, if failure.
        #[unsafe(method(initWithDevice:learningRate:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_learningRate(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            learning_rate: c_float,
        ) -> Retained<Self>;

        /// Full initialization for the momentum update
        ///
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `momentumScale`: The momentumScale to update momentum for values array
        ///
        /// Parameter `useNesterovMomentum`: Use the Nesterov style momentum update
        ///
        /// Parameter `optimizerDescriptor`: The optimizerDescriptor which will have a bunch of properties to be applied
        ///
        ///
        /// Returns: A valid MPSNNOptimizerStochasticGradientDescent object or nil, if failure.
        #[unsafe(method(initWithDevice:momentumScale:useNesterovMomentum:optimizerDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_momentumScale_useNesterovMomentum_optimizerDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            momentum_scale: c_float,
            use_nesterov_momentum: bool,
            optimizer_descriptor: &MPSNNOptimizerDescriptor,
        ) -> Retained<Self>;

        /// Legacy, misspelled variant of
        /// [`initWithDevice_momentumScale_useNesterovMomentum_optimizerDescriptor`][Self::initWithDevice_momentumScale_useNesterovMomentum_optimizerDescriptor].
        ///
        /// It takes the same parameters, with the Nesterov flag named `useNestrovMomentum`.
        /// Prefer the correctly spelled initializer.
        /// NOTE(review): newer SDKs appear to deprecate this selector. Confirm before relying on it.
        #[unsafe(method(initWithDevice:momentumScale:useNestrovMomentum:optimizerDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_momentumScale_useNestrovMomentum_optimizerDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            momentum_scale: c_float,
            use_nestrov_momentum: bool,
            optimizer_descriptor: &MPSNNOptimizerDescriptor,
        ) -> Retained<Self>;

        #[cfg(feature = "MPSMatrix")]
        /// Encode an MPSNNOptimizerStochasticGradientDescent object to a command buffer to perform out of place update
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `inputGradientVector`: A valid MPSVector object which specifies the input vector of gradients for this update.
        ///
        /// Parameter `inputValuesVector`: A valid MPSVector object which specifies the input vector of values to be updated.
        ///
        /// Parameter `inputMomentumVector`: A valid MPSVector object which specifies the gradient momentum vector which will
        /// be updated and overwritten.
        ///
        /// Parameter `resultValuesVector`: A valid MPSVector object which specifies the resultValues vector which will
        /// be updated and overwritten.
        ///
        ///
        /// The following operations would be applied
        ///
        /// useNesterov == NO:
        /// m[t] = momentumScale * m[t-1] + learningRate * g
        /// variable = variable - m[t]
        ///
        /// useNesterov == YES:
        /// m[t] = momentumScale * m[t-1] + g
        /// variable = variable - (learningRate * (g + m[t] * momentumScale))
        ///
        /// inputMomentumVector == nil
        /// variable = variable - (learningRate * g)
        ///
        /// where,
        /// g is gradient of error wrt variable
        /// m[t] is momentum of gradients it is a state we keep updating every update iteration
        #[unsafe(method(encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputMomentumVector:resultValuesVector:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_resultValuesVector(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_vector: &MPSVector,
            input_values_vector: &MPSVector,
            input_momentum_vector: Option<&MPSVector>,
            result_values_vector: &MPSVector,
        );

        #[cfg(feature = "MPSMatrix")]
        /// Matrix counterpart of
        /// [`encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_resultValuesVector`][Self::encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_resultValuesVector].
        ///
        /// It takes gradients, values, optional momentum and results as MPSMatrix objects instead
        /// of MPSVector objects. As with the vector form, the momentum argument may be `None`.
        #[unsafe(method(encodeToCommandBuffer:inputGradientMatrix:inputValuesMatrix:inputMomentumMatrix:resultValuesMatrix:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientMatrix_inputValuesMatrix_inputMomentumMatrix_resultValuesMatrix(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_matrix: &MPSMatrix,
            input_values_matrix: &MPSMatrix,
            input_momentum_matrix: Option<&MPSMatrix>,
            result_values_matrix: &MPSMatrix,
        );

        #[cfg(all(
            feature = "MPSCNNConvolution",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerStochasticGradientDescent object to a command buffer to perform out of place update
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `convolutionGradientState`: A valid MPSCNNConvolutionGradientState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `convolutionSourceState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the input state with values to be updated.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. The index 0 corresponds to weights, index 1 corresponds to biases, array can be of
        /// size 1 in which case biases won't be updated
        ///
        /// Parameter `resultState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        ///
        /// The following operations would be applied
        ///
        /// useNesterov == NO:
        /// m[t] = momentumScale * m[t-1] + learningRate * g
        /// variable = variable - m[t]
        ///
        /// useNesterov == YES:
        /// m[t] = momentumScale * m[t-1] + g
        /// variable = variable - (learningRate * (g + m[t] * momentumScale))
        ///
        /// inputMomentumVectors == nil
        /// variable = variable - (learningRate * g)
        ///
        /// where,
        /// g is gradient of error wrt variable
        /// m[t] is momentum of gradients it is a state we keep updating every update iteration
        #[unsafe(method(encodeToCommandBuffer:convolutionGradientState:convolutionSourceState:inputMomentumVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_convolutionGradientState_convolutionSourceState_inputMomentumVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            convolution_gradient_state: &MPSCNNConvolutionGradientState,
            convolution_source_state: &MPSCNNConvolutionWeightsAndBiasesState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNConvolutionWeightsAndBiasesState,
        );

        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerStochasticGradientDescent object to a command buffer to perform out of place update
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients and original gamma/beta for this update.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. The index 0 corresponds to gamma, index 1 corresponds to beta, array can be of
        /// size 1 in which case beta won't be updated
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        ///
        /// The following operations would be applied
        ///
        /// useNesterov == NO:
        /// m[t] = momentumScale * m[t-1] + learningRate * g
        /// variable = variable - m[t]
        ///
        /// useNesterov == YES:
        /// m[t] = momentumScale * m[t-1] + g
        /// variable = variable - (learningRate * (g + m[t] * momentumScale))
        ///
        /// inputMomentumVectors == nil
        /// variable = variable - (learningRate * g)
        ///
        /// where,
        /// g is gradient of error wrt variable
        /// m[t] is momentum of gradients it is a state we keep updating every update iteration
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationState:inputMomentumVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationState_inputMomentumVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );

        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerStochasticGradientDescent object to a command buffer to perform out of place update
        ///
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationGradientState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `batchNormalizationSourceState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with original gamma/beta for this update.
        ///
        /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the gradient momentum vectors which will
        /// be updated and overwritten. The index 0 corresponds to gamma, index 1 corresponds to beta, array can be of
        /// size 1 in which case beta won't be updated
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        ///
        /// The following operations would be applied
        ///
        /// useNesterov == NO:
        /// m[t] = momentumScale * m[t-1] + learningRate * g
        /// variable = variable - m[t]
        ///
        /// useNesterov == YES:
        /// m[t] = momentumScale * m[t-1] + g
        /// variable = variable - (learningRate * (g + m[t] * momentumScale))
        ///
        /// inputMomentumVectors == nil
        /// variable = variable - (learningRate * g)
        ///
        /// where,
        /// g is gradient of error wrt variable
        /// m[t] is momentum of gradients it is a state we keep updating every update iteration
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationGradientState:batchNormalizationSourceState:inputMomentumVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationGradientState_batchNormalizationSourceState_inputMomentumVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_gradient_state: &MPSCNNBatchNormalizationState,
            batch_normalization_source_state: &MPSCNNBatchNormalizationState,
            input_momentum_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );
    );
}
700
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerStochasticGradientDescent {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels
        ///
        /// This isn't the right interface to decode a MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// Returns `None` if decoding fails.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;

        /// NSSecureCoding compatibility
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// -initWithCoder: should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly. To avoid
        /// that problem, use `initWithCoder:device:`
        /// ([`initWithCoder_device`][Self::initWithCoder_device]) instead.
        ///
        /// Parameter `a_decoder`: The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSKernel object, or `None` on failure.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;
    );
}
750
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerStochasticGradientDescent {
    extern_methods!(
        /// Plain `-[NSObject init]`, which takes no Metal device.
        ///
        /// NOTE(review): prefer `initWithDevice` and its variants. Confirm whether plain
        /// `init` is supported for this kernel.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Plain `+[NSObject new]`, which takes no Metal device. The same caveat as `init` applies.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
764
extern_class!(
    /// The MPSNNOptimizerRMSProp performs an RMSProp update.
    /// RMSProp is also known as root mean square propagation.
    ///
    /// ```text
    /// s[t] = decay * s[t-1] + (1 - decay) * (g ^ 2)
    /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
    /// ```
    ///
    /// where:
    /// - `g` is the gradient of the error with respect to `variable`
    /// - `s[t]` is the weighted sum of squares of gradients
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizerrmsprop?language=objc)
    #[unsafe(super(MPSNNOptimizer, MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSNNOptimizerRMSProp;
);
782
// Protocol conformances declared for `MPSNNOptimizerRMSProp`. Each one is gated
// on the same features as the class definition so they compile together.

// Archiving support (see also the NSSecureCoding conformance below).
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSNNOptimizerRMSProp {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSNNOptimizerRMSProp {}
);

// `Result = Self`: copying this optimizer is declared to produce another
// `MPSNNOptimizerRMSProp` (objc2's `CopyingHelper::Result` is the type `copy()` returns).
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSNNOptimizerRMSProp {
    type Result = Self;
}

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSNNOptimizerRMSProp {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSNNOptimizerRMSProp {}
);
807
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerRMSProp {
    extern_methods!(
        /// The decay at which we update sumOfSquares.
        ///
        /// Default value is 0.9
        #[unsafe(method(decay))]
        #[unsafe(method_family = none)]
        pub unsafe fn decay(&self) -> c_double;

        /// The epsilon at which we update values.
        ///
        /// This value is usually used to avoid dividing by 0; default value is 1e-8.
        #[unsafe(method(epsilon))]
        #[unsafe(method_family = none)]
        pub unsafe fn epsilon(&self) -> c_float;

        /// Initializes an RMSProp optimizer that will execute on `device`.
        ///
        /// NOTE(review): no upstream documentation is attached to this initializer;
        /// presumably it uses the defaults listed on [`decay`][Self::decay] and
        /// [`epsilon`][Self::epsilon] — confirm against Apple's documentation.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;

        /// Convenience initialization for the RMSProp update.
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `learningRate`: The learningRate which will be applied.
        ///
        /// Returns: A valid MPSNNOptimizerRMSProp object or nil, if failure.
        ///
        /// NOTE(review): this binding returns a non-optional `Retained<Self>`, which
        /// cannot hold nil; confirm how a nil result from Objective-C is surfaced.
        #[unsafe(method(initWithDevice:learningRate:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_learningRate(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            learning_rate: c_float,
        ) -> Retained<Self>;

        /// Full initialization for the RMSProp update.
        ///
        /// Parameter `device`: The device on which the kernel will execute.
        ///
        /// Parameter `decay`: The decay used to update sumOfSquares.
        ///
        /// Parameter `epsilon`: The epsilon which will be applied.
        ///
        /// Parameter `optimizerDescriptor`: The optimizerDescriptor whose properties will be applied.
        ///
        /// Returns: A valid MPSNNOptimizerRMSProp object or nil, if failure.
        ///
        /// NOTE(review): this binding returns a non-optional `Retained<Self>`, which
        /// cannot hold nil; confirm how a nil result from Objective-C is surfaced.
        #[unsafe(method(initWithDevice:decay:epsilon:optimizerDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_decay_epsilon_optimizerDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            decay: c_double,
            epsilon: c_float,
            optimizer_descriptor: &MPSNNOptimizerDescriptor,
        ) -> Retained<Self>;

        #[cfg(feature = "MPSMatrix")]
        /// Encode an MPSNNOptimizerRMSProp object to a command buffer to perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `inputGradientVector`: A valid MPSVector object which specifies the input vector of gradients for this update.
        ///
        /// Parameter `inputValuesVector`: A valid MPSVector object which specifies the input vector of values to be updated.
        ///
        /// Parameter `inputSumOfSquaresVector`: A valid MPSVector object which specifies the running
        /// sum-of-squares vector (`s` in the formula below), which will be updated and overwritten.
        ///
        /// Parameter `resultValuesVector`: A valid MPSVector object which specifies the resultValues vector which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// s[t] = decay * s[t-1] + (1 - decay) * (g ^ 2)
        /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
        /// ```
        ///
        /// where:
        /// - `g` is the gradient of the error with respect to `variable`
        /// - `s[t]` is the weighted sum of squares of gradients
        #[unsafe(method(encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputSumOfSquaresVector:resultValuesVector:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputSumOfSquaresVector_resultValuesVector(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_vector: &MPSVector,
            input_values_vector: &MPSVector,
            input_sum_of_squares_vector: &MPSVector,
            result_values_vector: &MPSVector,
        );

        #[cfg(feature = "MPSMatrix")]
        /// Matrix variant of the vector encode method above: the parameters mirror it,
        /// with MPSMatrix operands in place of MPSVector.
        ///
        /// NOTE(review): no upstream documentation is attached; presumably the same
        /// RMSProp update is applied element-wise — confirm against Apple's documentation.
        #[unsafe(method(encodeToCommandBuffer:inputGradientMatrix:inputValuesMatrix:inputSumOfSquaresMatrix:resultValuesMatrix:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_inputGradientMatrix_inputValuesMatrix_inputSumOfSquaresMatrix_resultValuesMatrix(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            input_gradient_matrix: &MPSMatrix,
            input_values_matrix: &MPSMatrix,
            input_sum_of_squares_matrix: &MPSMatrix,
            result_values_matrix: &MPSMatrix,
        );

        #[cfg(all(
            feature = "MPSCNNConvolution",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerRMSProp object to a command buffer to perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `convolutionGradientState`: A valid MPSCNNConvolutionGradientState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `convolutionSourceState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the input state with values to be updated.
        ///
        /// Parameter `inputSumOfSquaresVectors`: An array of MPSVector objects which specifies the gradient sumOfSquares vectors which will
        /// be updated and overwritten. Index 0 corresponds to weights and index 1 to biases; the array can be of
        /// size 1, in which case biases won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// s[t] = decay * s[t-1] + (1 - decay) * (g ^ 2)
        /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
        /// ```
        ///
        /// where:
        /// - `g` is the gradient of the error with respect to `variable`
        /// - `s[t]` is the weighted sum of squares of gradients
        #[unsafe(method(encodeToCommandBuffer:convolutionGradientState:convolutionSourceState:inputSumOfSquaresVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_convolutionGradientState_convolutionSourceState_inputSumOfSquaresVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            convolution_gradient_state: &MPSCNNConvolutionGradientState,
            convolution_source_state: &MPSCNNConvolutionWeightsAndBiasesState,
            input_sum_of_squares_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNConvolutionWeightsAndBiasesState,
        );

        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerRMSProp object to a command buffer to perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients and original gamma/beta for this update.
        ///
        /// Parameter `inputSumOfSquaresVectors`: An array of MPSVector objects which specifies the gradient sumOfSquares vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// s[t] = decay * s[t-1] + (1 - decay) * (g ^ 2)
        /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
        /// ```
        ///
        /// where:
        /// - `g` is the gradient of the error with respect to `variable`
        /// - `s[t]` is the weighted sum of squares of gradients
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationState:inputSumOfSquaresVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationState_inputSumOfSquaresVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            input_sum_of_squares_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );

        #[cfg(all(
            feature = "MPSCNNBatchNormalization",
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSMatrix",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode an MPSNNOptimizerRMSProp object to a command buffer to perform an out-of-place update.
        ///
        /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
        ///
        /// Parameter `batchNormalizationGradientState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with gradients for this update.
        ///
        /// Parameter `batchNormalizationSourceState`: A valid MPSCNNBatchNormalizationState object which specifies the input state with original gamma/beta for this update.
        ///
        /// Parameter `inputSumOfSquaresVectors`: An array of MPSVector objects which specifies the gradient sumOfSquares vectors which will
        /// be updated and overwritten. Index 0 corresponds to gamma and index 1 to beta; the array can be of
        /// size 1, in which case beta won't be updated.
        ///
        /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which specifies the resultValues state which will
        /// be updated and overwritten.
        ///
        /// The following operations would be applied:
        ///
        /// ```text
        /// s[t] = decay * s[t-1] + (1 - decay) * (g ^ 2)
        /// variable = variable - learningRate * g / (sqrt(s[t]) + epsilon)
        /// ```
        ///
        /// where:
        /// - `g` is the gradient of the error with respect to `variable`
        /// - `s[t]` is the weighted sum of squares of gradients
        #[unsafe(method(encodeToCommandBuffer:batchNormalizationGradientState:batchNormalizationSourceState:inputSumOfSquaresVectors:resultState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_batchNormalizationGradientState_batchNormalizationSourceState_inputSumOfSquaresVectors_resultState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_gradient_state: &MPSCNNBatchNormalizationState,
            batch_normalization_source_state: &MPSCNNBatchNormalizationState,
            input_sum_of_squares_vectors: Option<&NSArray<MPSVector>>,
            result_state: &MPSCNNNormalizationGammaAndBetaState,
        );
    );
}
1046
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerRMSProp {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels.
        ///
        /// This isn't the right interface to decode a MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use,
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;

        /// NSSecureCoding compatibility.
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// `-initWithCoder:` should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly. To avoid
        /// that problem, use `-initWithCoder:device:` instead.
        ///
        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSKernel object, or nil if failure (surfaced here as `None`).
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;
    );
}
1096
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerRMSProp {
    extern_methods!(
        /// Binding for the `init` initializer inherited from `NSObject`.
        ///
        /// NOTE(review): unlike `initWithDevice`, this takes no `MTLDevice`; confirm
        /// against Apple's documentation that the resulting kernel is usable.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// Binding for the `new` class method inherited from `NSObject`
        /// (`new` method family, so the returned object is already retained).
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
1110
extern_class!(
    /// The MPSNNOptimizerAdam performs an Adam update.
    ///
    /// Initialization time:
    ///
    /// ```text
    /// m[0] = 0 (Initialize initial 1st moment vector aka momentum, user is responsible for this)
    /// v[0] = 0 (Initialize initial 2nd moment vector aka velocity, user is responsible for this)
    /// t = 0 (Initialize timestep)
    /// ```
    ///
    /// <https://arxiv.org/abs/1412.6980>
    ///
    /// At update time:
    ///
    /// ```text
    /// t = t + 1
    /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
    ///
    /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
    /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
    /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
    /// ```
    ///
    /// where:
    /// - `g` is the gradient of the error with respect to `variable`
    /// - `v[t]` is the velocity
    /// - `m[t]` is the momentum
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpsnnoptimizeradam?language=objc)
    #[unsafe(super(MPSNNOptimizer, MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    #[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSNNOptimizerAdam;
);
1140
// Protocol conformances declared for `MPSNNOptimizerAdam`. Each one is gated
// on the same features as the class definition so they compile together.

// Archiving support (see also the NSSecureCoding conformance below).
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSNNOptimizerAdam {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSNNOptimizerAdam {}
);

// `Result = Self`: copying this optimizer is declared to produce another
// `MPSNNOptimizerAdam` (objc2's `CopyingHelper::Result` is the type `copy()` returns).
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSNNOptimizerAdam {
    type Result = Self;
}

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSNNOptimizerAdam {}
);

#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSNNOptimizerAdam {}
);
1165
1166#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
1167impl MPSNNOptimizerAdam {
1168 extern_methods!(
1169 /// The beta1 at which we update values
1170 ///
1171 /// Default value is 0.9
1172 #[unsafe(method(beta1))]
1173 #[unsafe(method_family = none)]
1174 pub unsafe fn beta1(&self) -> c_double;
1175
1176 /// The beta2 at which we update values
1177 ///
1178 /// Default value is 0.999
1179 #[unsafe(method(beta2))]
1180 #[unsafe(method_family = none)]
1181 pub unsafe fn beta2(&self) -> c_double;
1182
1183 /// The epsilon at which we update values
1184 ///
1185 /// This value is usually used to ensure to avoid divide by 0, default value is 1e-8
1186 #[unsafe(method(epsilon))]
1187 #[unsafe(method_family = none)]
1188 pub unsafe fn epsilon(&self) -> c_float;
1189
1190 /// Current timeStep for the update, number of times update has occurred
1191 #[unsafe(method(timeStep))]
1192 #[unsafe(method_family = none)]
1193 pub unsafe fn timeStep(&self) -> NSUInteger;
1194
1195 /// Setter for [`timeStep`][Self::timeStep].
1196 #[unsafe(method(setTimeStep:))]
1197 #[unsafe(method_family = none)]
1198 pub unsafe fn setTimeStep(&self, time_step: NSUInteger);
1199
1200 #[unsafe(method(initWithDevice:))]
1201 #[unsafe(method_family = init)]
1202 pub unsafe fn initWithDevice(
1203 this: Allocated<Self>,
1204 device: &ProtocolObject<dyn MTLDevice>,
1205 ) -> Retained<Self>;
1206
1207 /// Convenience initialization for the adam update
1208 ///
1209 ///
1210 /// Parameter `device`: The device on which the kernel will execute.
1211 ///
1212 /// Parameter `learningRate`: The learningRate at which we will update values
1213 ///
1214 ///
1215 /// Returns: A valid MPSNNOptimizerAdam object or nil, if failure.
1216 #[unsafe(method(initWithDevice:learningRate:))]
1217 #[unsafe(method_family = init)]
1218 pub unsafe fn initWithDevice_learningRate(
1219 this: Allocated<Self>,
1220 device: &ProtocolObject<dyn MTLDevice>,
1221 learning_rate: c_float,
1222 ) -> Retained<Self>;
1223
1224 /// Full initialization for the adam update
1225 ///
1226 ///
1227 /// Parameter `device`: The device on which the kernel will execute.
1228 ///
1229 /// Parameter `beta1`: The beta1 to update values
1230 ///
1231 /// Parameter `beta2`: The beta2 to update values
1232 ///
1233 /// Parameter `epsilon`: The epsilon at which we update values
1234 ///
1235 /// Parameter `timeStep`: The timeStep at which values will start updating
1236 ///
1237 /// Parameter `optimizerDescriptor`: The optimizerDescriptor which will have a bunch of properties to be applied
1238 ///
1239 ///
1240 /// Returns: A valid MPSNNOptimizerAdam object or nil, if failure.
1241 #[unsafe(method(initWithDevice:beta1:beta2:epsilon:timeStep:optimizerDescriptor:))]
1242 #[unsafe(method_family = init)]
1243 pub unsafe fn initWithDevice_beta1_beta2_epsilon_timeStep_optimizerDescriptor(
1244 this: Allocated<Self>,
1245 device: &ProtocolObject<dyn MTLDevice>,
1246 beta1: c_double,
1247 beta2: c_double,
1248 epsilon: c_float,
1249 time_step: NSUInteger,
1250 optimizer_descriptor: &MPSNNOptimizerDescriptor,
1251 ) -> Retained<Self>;
1252
1253 #[cfg(feature = "MPSMatrix")]
1254 /// Encode an MPSNNOptimizerAdam object to a command buffer to perform out of place update
1255 ///
1256 ///
1257 /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
1258 ///
1259 /// Parameter `inputGradientVector`: A valid MPSVector object which specifies the input vector of gradients for this update.
1260 ///
1261 /// Parameter `inputValuesVector`: A valid MPSVector object which specifies the input vector of values to be updated.
1262 ///
1263 /// Parameter `inputMomentumVector`: A valid MPSVector object which specifies the gradient momentum vector which will
1264 /// be updated and overwritten.
1265 ///
1266 /// Parameter `inputVelocityVector`: A valid MPSVector object which specifies the gradient velocity vector which will
1267 /// be updated and overwritten.
1268 ///
1269 /// Parameter `resultValuesVector`: A valid MPSVector object which specifies the resultValues vector which will
1270 /// be updated and overwritten.
1271 ///
1272 ///
1273 /// The following operations would be applied
1274 ///
1275 /// t = t + 1
1276 /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
1277 ///
1278 /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
1279 /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
1280 /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
1281 #[unsafe(method(encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputMomentumVector:inputVelocityVector:resultValuesVector:))]
1282 #[unsafe(method_family = none)]
1283 pub unsafe fn encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_inputVelocityVector_resultValuesVector(
1284 &self,
1285 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1286 input_gradient_vector: &MPSVector,
1287 input_values_vector: &MPSVector,
1288 input_momentum_vector: &MPSVector,
1289 input_velocity_vector: &MPSVector,
1290 result_values_vector: &MPSVector,
1291 );
1292
1293 #[cfg(feature = "MPSMatrix")]
1294 #[unsafe(method(encodeToCommandBuffer:inputGradientMatrix:inputValuesMatrix:inputMomentumMatrix:inputVelocityMatrix:resultValuesMatrix:))]
1295 #[unsafe(method_family = none)]
1296 pub unsafe fn encodeToCommandBuffer_inputGradientMatrix_inputValuesMatrix_inputMomentumMatrix_inputVelocityMatrix_resultValuesMatrix(
1297 &self,
1298 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1299 input_gradient_matrix: &MPSMatrix,
1300 input_values_matrix: &MPSMatrix,
1301 input_momentum_matrix: &MPSMatrix,
1302 input_velocity_matrix: &MPSMatrix,
1303 result_values_matrix: &MPSMatrix,
1304 );
1305
1306 #[cfg(feature = "MPSMatrix")]
1307 /// Encode an AMSGrad variant of MPSNNOptimizerAdam object to a command buffer to perform out of place update
1308 ///
1309 /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
1310 ///
1311 /// Parameter `inputGradientVector`: A valid MPSVector object which specifies the input vector of gradients for this update.
1312 ///
1313 /// Parameter `inputValuesVector`: A valid MPSVector object which specifies the input vector of values to be updated.
1314 ///
1315 /// Parameter `inputMomentumVector`: A valid MPSVector object which specifies the gradient momentum vector which will
1316 /// be updated and overwritten.
1317 ///
1318 /// Parameter `inputVelocityVector`: A valid MPSVector object which specifies the gradient velocity vector which will
1319 /// be updated and overwritten.
1320 ///
1321 /// Parameter `maximumVelocityVector`: A valid MPSVector object which specifies the maximum velocity vector which will
1322 /// be updated and overwritten. May be nil, if nil then normal Adam optimizer behaviour is followed.
1323 ///
1324 /// Parameter `resultValuesVector`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the resultValues state which will
1325 /// be updated and overwritten.
1326 ///
1327 ///
1328 /// The following operations would be applied
1329 /// At update time:
1330 /// t = t + 1
1331 /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
1332 ///
1333 /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
1334 /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
1335 /// maxVel[t] = max(maxVel[t-1],v[t])
1336 /// variable = variable - lr[t] * m[t] / (sqrt(maxVel[t]) + epsilon)
1337 #[unsafe(method(encodeToCommandBuffer:inputGradientVector:inputValuesVector:inputMomentumVector:inputVelocityVector:maximumVelocityVector:resultValuesVector:))]
1338 #[unsafe(method_family = none)]
1339 pub unsafe fn encodeToCommandBuffer_inputGradientVector_inputValuesVector_inputMomentumVector_inputVelocityVector_maximumVelocityVector_resultValuesVector(
1340 &self,
1341 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1342 input_gradient_vector: &MPSVector,
1343 input_values_vector: &MPSVector,
1344 input_momentum_vector: &MPSVector,
1345 input_velocity_vector: &MPSVector,
1346 maximum_velocity_vector: Option<&MPSVector>,
1347 result_values_vector: &MPSVector,
1348 );
1349
1350 #[cfg(feature = "MPSMatrix")]
1351 #[unsafe(method(encodeToCommandBuffer:inputGradientMatrix:inputValuesMatrix:inputMomentumMatrix:inputVelocityMatrix:maximumVelocityMatrix:resultValuesMatrix:))]
1352 #[unsafe(method_family = none)]
1353 pub unsafe fn encodeToCommandBuffer_inputGradientMatrix_inputValuesMatrix_inputMomentumMatrix_inputVelocityMatrix_maximumVelocityMatrix_resultValuesMatrix(
1354 &self,
1355 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1356 input_gradient_matrix: &MPSMatrix,
1357 input_values_matrix: &MPSMatrix,
1358 input_momentum_matrix: &MPSMatrix,
1359 input_velocity_matrix: &MPSMatrix,
1360 maximum_velocity_matrix: Option<&MPSMatrix>,
1361 result_values_matrix: &MPSMatrix,
1362 );
1363
1364 #[cfg(all(
1365 feature = "MPSCNNConvolution",
1366 feature = "MPSMatrix",
1367 feature = "MPSNNGradientState",
1368 feature = "MPSState"
1369 ))]
1370 /// Encode an MPSNNOptimizerAdam object to a command buffer to perform out of place update
1371 ///
1372 ///
1373 /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
1374 ///
1375 /// Parameter `convolutionGradientState`: A valid MPSCNNConvolutionGradientState object which specifies the input state with gradients for this update.
1376 ///
1377 /// Parameter `convolutionSourceState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the input state with values to be updated.
1378 ///
1379 /// Parameter `inputMomentumVectors`: An array MPSVector object which specifies the gradient momentum vectors which will
1380 /// be updated and overwritten. The index 0 corresponds to weights, index 1 corresponds to biases, array can be of
1381 /// size 1 in which case biases won't be updated
1382 ///
1383 /// Parameter `inputVelocityVectors`: An array MPSVector object which specifies the gradient velocity vectors which will
1384 /// be updated and overwritten. The index 0 corresponds to weights, index 1 corresponds to biases, array can be of
1385 /// size 1 in which case biases won't be updated
1386 ///
1387 /// Parameter `resultState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the resultValues state which will
1388 /// be updated and overwritten.
1389 ///
1390 ///
1391 /// The following operations would be applied
1392 ///
1393 /// t = t + 1
1394 /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
1395 ///
1396 /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
1397 /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
1398 /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
1399 #[unsafe(method(encodeToCommandBuffer:convolutionGradientState:convolutionSourceState:inputMomentumVectors:inputVelocityVectors:resultState:))]
1400 #[unsafe(method_family = none)]
1401 pub unsafe fn encodeToCommandBuffer_convolutionGradientState_convolutionSourceState_inputMomentumVectors_inputVelocityVectors_resultState(
1402 &self,
1403 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1404 convolution_gradient_state: &MPSCNNConvolutionGradientState,
1405 convolution_source_state: &MPSCNNConvolutionWeightsAndBiasesState,
1406 input_momentum_vectors: Option<&NSArray<MPSVector>>,
1407 input_velocity_vectors: Option<&NSArray<MPSVector>>,
1408 result_state: &MPSCNNConvolutionWeightsAndBiasesState,
1409 );
1410
1411 #[cfg(all(
1412 feature = "MPSCNNConvolution",
1413 feature = "MPSMatrix",
1414 feature = "MPSNNGradientState",
1415 feature = "MPSState"
1416 ))]
1417 /// Encode an AMSGrad variant of MPSNNOptimizerAdam object to a command buffer to perform out of place update
1418 ///
1419 ///
1420 /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
1421 ///
1422 /// Parameter `convolutionGradientState`: A valid MPSCNNConvolutionGradientState object which specifies the input state with gradients for this update.
1423 ///
1424 /// Parameter `convolutionSourceState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the input state with values to be updated.
1425 ///
1426 /// Parameter `inputMomentumVectors`: An array MPSVector object which specifies the gradient momentum vectors which will
1427 /// be updated and overwritten. The index 0 corresponds to weights, index 1 corresponds to biases, array can be of
1428 /// size 1 in which case biases won't be updated
1429 ///
1430 /// Parameter `inputVelocityVectors`: An array MPSVector object which specifies the gradient velocity vectors which will
1431 /// be updated and overwritten. The index 0 corresponds to weights, index 1 corresponds to biases, array can be of
1432 /// size 1 in which case biases won't be updated
1433 ///
1434 /// Parameter `maximumVelocityVectors`: An array MPSVector object which specifies the maximum velocity vectors which will
1435 /// be updated and overwritten. The index 0 corresponds to weights, index 1 corresponds to biases, array can be of
1436 /// size 1 in which case biases won't be updated. May be nil, if nil then normal Adam optimizer behaviour is followed.
1437 ///
1438 /// Parameter `resultState`: A valid MPSCNNConvolutionWeightsAndBiasesState object which specifies the resultValues state which will
1439 /// be updated and overwritten.
1440 ///
1441 ///
1442 /// The following operations would be applied
1443 /// At update time:
1444 /// t = t + 1
1445 /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
1446 ///
1447 /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
1448 /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
1449 /// maxVel[t] = max(maxVel[t-1],v[t])
1450 /// variable = variable - lr[t] * m[t] / (sqrt(maxVel[t]) + epsilon)
1451 #[unsafe(method(encodeToCommandBuffer:convolutionGradientState:convolutionSourceState:inputMomentumVectors:inputVelocityVectors:maximumVelocityVectors:resultState:))]
1452 #[unsafe(method_family = none)]
1453 pub unsafe fn encodeToCommandBuffer_convolutionGradientState_convolutionSourceState_inputMomentumVectors_inputVelocityVectors_maximumVelocityVectors_resultState(
1454 &self,
1455 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
1456 convolution_gradient_state: &MPSCNNConvolutionGradientState,
1457 convolution_source_state: &MPSCNNConvolutionWeightsAndBiasesState,
1458 input_momentum_vectors: &NSArray<MPSVector>,
1459 input_velocity_vectors: &NSArray<MPSVector>,
1460 maximum_velocity_vectors: Option<&NSArray<MPSVector>>,
1461 result_state: &MPSCNNConvolutionWeightsAndBiasesState,
1462 );
1463
 #[cfg(all(
 feature = "MPSCNNBatchNormalization",
 feature = "MPSCNNNormalizationWeights",
 feature = "MPSMatrix",
 feature = "MPSNNGradientState",
 feature = "MPSState"
 ))]
 /// Encode an MPSNNOptimizerAdam object to a command buffer to perform an out-of-place
 /// update of batch normalization gamma/beta, reading both the gradients and the original
 /// gamma/beta from a single state object.
 ///
 /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
 ///
 /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object which
 /// specifies the input state with gradients and original gamma/beta for this update.
 ///
 /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the
 /// gradient momentum vectors which will be updated and overwritten. Index 0 corresponds
 /// to gamma and index 1 to beta; the array can be of size 1, in which case beta won't be
 /// updated.
 ///
 /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the
 /// gradient velocity vectors which will be updated and overwritten. Index 0 corresponds
 /// to gamma and index 1 to beta; the array can be of size 1, in which case beta won't be
 /// updated.
 ///
 /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which
 /// specifies the resultValues state which will be updated and overwritten.
 ///
 /// Both vector arrays are typed `Option` in this overload. The AMSGrad overload that also
 /// takes `maximumVelocityVectors` requires them.
 /// NOTE(review): the behaviour when `None` (nil) is passed is not described here. Confirm
 /// against Apple's documentation before relying on it.
 ///
 /// The following operations are applied. Here `g` is the gradient; `m`/`v` are presumably
 /// the momentum/velocity vectors passed in.
 ///
 /// ```text
 /// t = t + 1
 /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
 ///
 /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
 /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
 /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
 /// ```
 ///
 /// # Safety
 ///
 /// This forwards directly to the Objective-C method named in the `method` attribute
 /// below, whose preconditions the compiler cannot check.
 /// NOTE(review): the exact requirements on the states and vectors are not stated in this
 /// binding. Confirm against Apple's documentation.
 #[unsafe(method(encodeToCommandBuffer:batchNormalizationState:inputMomentumVectors:inputVelocityVectors:resultState:))]
 #[unsafe(method_family = none)]
 pub unsafe fn encodeToCommandBuffer_batchNormalizationState_inputMomentumVectors_inputVelocityVectors_resultState(
 &self,
 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
 batch_normalization_state: &MPSCNNBatchNormalizationState,
 input_momentum_vectors: Option<&NSArray<MPSVector>>,
 input_velocity_vectors: Option<&NSArray<MPSVector>>,
 result_state: &MPSCNNNormalizationGammaAndBetaState,
 );
1508
 #[cfg(all(
 feature = "MPSCNNBatchNormalization",
 feature = "MPSCNNNormalizationWeights",
 feature = "MPSMatrix",
 feature = "MPSNNGradientState",
 feature = "MPSState"
 ))]
 /// Encode the AMSGrad variant of an MPSNNOptimizerAdam object to a command buffer to
 /// perform an out-of-place update of batch normalization gamma/beta, reading both the
 /// gradients and the original gamma/beta from a single state object.
 ///
 /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
 ///
 /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object which
 /// specifies the input state with gradients and original gamma/beta for this update.
 ///
 /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the
 /// gradient momentum vectors which will be updated and overwritten. Index 0 corresponds
 /// to gamma and index 1 to beta; the array can be of size 1, in which case beta won't be
 /// updated.
 ///
 /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the
 /// gradient velocity vectors which will be updated and overwritten. Index 0 corresponds
 /// to gamma and index 1 to beta; the array can be of size 1, in which case beta won't be
 /// updated.
 ///
 /// Parameter `maximumVelocityVectors`: An array of MPSVector objects which specifies the
 /// maximum velocity vectors which will be updated and overwritten. Index 0 corresponds to
 /// gamma and index 1 to beta, matching `inputVelocityVectors`, since `maxVel[t]` is taken
 /// element-wise against `v[t]`. The array can be of size 1, in which case beta won't be
 /// updated. It may be `None` (nil), in which case normal Adam optimizer behaviour is
 /// followed.
 ///
 /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which
 /// specifies the resultValues state which will be updated and overwritten.
 ///
 /// At update time, the following operations are applied:
 ///
 /// ```text
 /// t = t + 1
 /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
 ///
 /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
 /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
 /// maxVel[t] = max(maxVel[t-1],v[t])
 /// variable = variable - lr[t] * m[t] / (sqrt(maxVel[t]) + epsilon)
 /// ```
 ///
 /// # Safety
 ///
 /// This forwards directly to the Objective-C method named in the `method` attribute
 /// below, whose preconditions the compiler cannot check.
 /// NOTE(review): the exact requirements on the states and vectors are not stated in this
 /// binding. Confirm against Apple's documentation.
 #[unsafe(method(encodeToCommandBuffer:batchNormalizationState:inputMomentumVectors:inputVelocityVectors:maximumVelocityVectors:resultState:))]
 #[unsafe(method_family = none)]
 pub unsafe fn encodeToCommandBuffer_batchNormalizationState_inputMomentumVectors_inputVelocityVectors_maximumVelocityVectors_resultState(
 &self,
 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
 batch_normalization_state: &MPSCNNBatchNormalizationState,
 input_momentum_vectors: &NSArray<MPSVector>,
 input_velocity_vectors: &NSArray<MPSVector>,
 maximum_velocity_vectors: Option<&NSArray<MPSVector>>,
 result_state: &MPSCNNNormalizationGammaAndBetaState,
 );
1559
 #[cfg(all(
 feature = "MPSCNNBatchNormalization",
 feature = "MPSCNNNormalizationWeights",
 feature = "MPSMatrix",
 feature = "MPSNNGradientState",
 feature = "MPSState"
 ))]
 /// Encode an MPSNNOptimizerAdam object to a command buffer to perform an out-of-place
 /// update of batch normalization gamma/beta, taking the gradients and the original
 /// gamma/beta from two separate state objects.
 ///
 /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
 ///
 /// Parameter `batchNormalizationGradientState`: A valid MPSCNNBatchNormalizationState
 /// object which specifies the input state with gradients for this update.
 ///
 /// Parameter `batchNormalizationSourceState`: A valid MPSCNNBatchNormalizationState object
 /// which specifies the input state with original gamma/beta for this update.
 ///
 /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the
 /// gradient momentum vectors which will be updated and overwritten. Index 0 corresponds
 /// to gamma and index 1 to beta; the array can be of size 1, in which case beta won't be
 /// updated.
 ///
 /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the
 /// gradient velocity vectors which will be updated and overwritten. Index 0 corresponds
 /// to gamma and index 1 to beta; the array can be of size 1, in which case beta won't be
 /// updated.
 ///
 /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which
 /// specifies the resultValues state which will be updated and overwritten.
 ///
 /// Both vector arrays are typed `Option` in this overload. The AMSGrad overload that also
 /// takes `maximumVelocityVectors` requires them.
 /// NOTE(review): the behaviour when `None` (nil) is passed is not described here. Confirm
 /// against Apple's documentation before relying on it.
 ///
 /// The following operations are applied. Here `g` is the gradient; `m`/`v` are presumably
 /// the momentum/velocity vectors passed in.
 ///
 /// ```text
 /// t = t + 1
 /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
 ///
 /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
 /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
 /// variable = variable - lr[t] * m[t] / (sqrt(v[t]) + epsilon)
 /// ```
 ///
 /// # Safety
 ///
 /// This forwards directly to the Objective-C method named in the `method` attribute
 /// below, whose preconditions the compiler cannot check.
 /// NOTE(review): the exact requirements on the states and vectors are not stated in this
 /// binding. Confirm against Apple's documentation.
 #[unsafe(method(encodeToCommandBuffer:batchNormalizationGradientState:batchNormalizationSourceState:inputMomentumVectors:inputVelocityVectors:resultState:))]
 #[unsafe(method_family = none)]
 pub unsafe fn encodeToCommandBuffer_batchNormalizationGradientState_batchNormalizationSourceState_inputMomentumVectors_inputVelocityVectors_resultState(
 &self,
 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
 batch_normalization_gradient_state: &MPSCNNBatchNormalizationState,
 batch_normalization_source_state: &MPSCNNBatchNormalizationState,
 input_momentum_vectors: Option<&NSArray<MPSVector>>,
 input_velocity_vectors: Option<&NSArray<MPSVector>>,
 result_state: &MPSCNNNormalizationGammaAndBetaState,
 );
1607
 #[cfg(all(
 feature = "MPSCNNBatchNormalization",
 feature = "MPSCNNNormalizationWeights",
 feature = "MPSMatrix",
 feature = "MPSNNGradientState",
 feature = "MPSState"
 ))]
 /// Encode the AMSGrad variant of an MPSNNOptimizerAdam object to a command buffer to
 /// perform an out-of-place update of batch normalization gamma/beta, taking the gradients
 /// and the original gamma/beta from two separate state objects.
 ///
 /// Parameter `commandBuffer`: A valid MTLCommandBuffer to receive the encoded kernel.
 ///
 /// Parameter `batchNormalizationGradientState`: A valid MPSCNNBatchNormalizationState
 /// object which specifies the input state with gradients for this update.
 ///
 /// Parameter `batchNormalizationSourceState`: A valid MPSCNNBatchNormalizationState object
 /// which specifies the input state with original gamma/beta for this update.
 ///
 /// Parameter `inputMomentumVectors`: An array of MPSVector objects which specifies the
 /// gradient momentum vectors which will be updated and overwritten. Index 0 corresponds
 /// to gamma and index 1 to beta; the array can be of size 1, in which case beta won't be
 /// updated.
 ///
 /// Parameter `inputVelocityVectors`: An array of MPSVector objects which specifies the
 /// gradient velocity vectors which will be updated and overwritten. Index 0 corresponds
 /// to gamma and index 1 to beta; the array can be of size 1, in which case beta won't be
 /// updated.
 ///
 /// Parameter `maximumVelocityVectors`: An array of MPSVector objects which specifies the
 /// maximum velocity vectors which will be updated and overwritten. Index 0 corresponds to
 /// gamma and index 1 to beta, matching `inputVelocityVectors`, since `maxVel[t]` is taken
 /// element-wise against `v[t]`. The array can be of size 1, in which case beta won't be
 /// updated. It may be `None` (nil), in which case normal Adam optimizer behaviour is
 /// followed.
 ///
 /// Parameter `resultState`: A valid MPSCNNNormalizationGammaAndBetaState object which
 /// specifies the resultValues state which will be updated and overwritten.
 ///
 /// At update time, the following operations are applied:
 ///
 /// ```text
 /// t = t + 1
 /// lr[t] = learningRate * sqrt(1 - beta2^t) / (1 - beta1^t)
 ///
 /// m[t] = beta1 * m[t-1] + (1 - beta1) * g
 /// v[t] = beta2 * v[t-1] + (1 - beta2) * (g ^ 2)
 /// maxVel[t] = max(maxVel[t-1],v[t])
 /// variable = variable - lr[t] * m[t] / (sqrt(maxVel[t]) + epsilon)
 /// ```
 ///
 /// # Safety
 ///
 /// This forwards directly to the Objective-C method named in the `method` attribute
 /// below, whose preconditions the compiler cannot check.
 /// NOTE(review): the exact requirements on the states and vectors are not stated in this
 /// binding. Confirm against Apple's documentation.
 #[unsafe(method(encodeToCommandBuffer:batchNormalizationGradientState:batchNormalizationSourceState:inputMomentumVectors:inputVelocityVectors:maximumVelocityVectors:resultState:))]
 #[unsafe(method_family = none)]
 pub unsafe fn encodeToCommandBuffer_batchNormalizationGradientState_batchNormalizationSourceState_inputMomentumVectors_inputVelocityVectors_maximumVelocityVectors_resultState(
 &self,
 command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
 batch_normalization_gradient_state: &MPSCNNBatchNormalizationState,
 batch_normalization_source_state: &MPSCNNBatchNormalizationState,
 input_momentum_vectors: &NSArray<MPSVector>,
 input_velocity_vectors: &NSArray<MPSVector>,
 maximum_velocity_vectors: Option<&NSArray<MPSVector>>,
 result_state: &MPSCNNNormalizationGammaAndBetaState,
 );
1661 );
1662}
1663
/// Methods declared on superclass `MPSKernel`.
///
/// These `NSCoder`-based initializers are re-declared on `MPSNNOptimizerAdam` so that
/// they return `Retained<Self>` (i.e. `Retained<MPSNNOptimizerAdam>`).
// Only compiled when both the `MPSCore` and `MPSKernel` crate features are enabled.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerAdam {
 extern_methods!(
 /// Called by NSCoder to decode MPSKernels.
 ///
 /// This isn't the right interface to decode an MPSKernel, but
 /// it is the one that NSCoder uses. To enable your NSCoder
 /// (e.g. NSKeyedUnarchiver) to set which device to use,
 /// extend the object to adopt the MPSDeviceProvider
 /// protocol. Otherwise, the Metal system default device
 /// will be used.
 ///
 /// Returns `None` if the Objective-C initializer returns nil
 /// (presumably on decode failure).
 ///
 /// # Safety
 ///
 /// `a_decoder` possibly has further requirements.
 #[unsafe(method(initWithCoder:))]
 #[unsafe(method_family = init)]
 pub unsafe fn initWithCoder(
 this: Allocated<Self>,
 a_decoder: &NSCoder,
 ) -> Option<Retained<Self>>;

 /// NSSecureCoding compatibility.
 ///
 /// While the standard NSSecureCoding/NSCoding method
 /// `-initWithCoder:` should work, since the file can't
 /// know which device your data is allocated on, we
 /// have to guess and may guess incorrectly. To avoid
 /// that problem, use `initWithCoder:device:` instead.
 ///
 /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSKernel
 ///
 /// Parameter `device`: The MTLDevice on which to make the MPSKernel
 ///
 /// Returns: A new MPSKernel object, or nil (`None`) on failure.
 ///
 /// # Safety
 ///
 /// `a_decoder` possibly has further requirements.
 #[unsafe(method(initWithCoder:device:))]
 #[unsafe(method_family = init)]
 pub unsafe fn initWithCoder_device(
 this: Allocated<Self>,
 a_decoder: &NSCoder,
 device: &ProtocolObject<dyn MTLDevice>,
 ) -> Option<Retained<Self>>;
 );
}
1713
/// Methods declared on superclass `NSObject`.
///
/// The basic `init`/`new` initializers, re-declared on `MPSNNOptimizerAdam` so that
/// they return `Retained<Self>` (i.e. `Retained<MPSNNOptimizerAdam>`).
// Only compiled when both the `MPSCore` and `MPSKernel` crate features are enabled.
#[cfg(all(feature = "MPSCore", feature = "MPSKernel"))]
impl MPSNNOptimizerAdam {
 extern_methods!(
 /// Initializes an allocated instance by sending it the Objective-C `init` message.
 ///
 /// NOTE(review): this initializer takes no `MTLDevice`. Confirm whether an optimizer
 /// created this way is usable, or whether a device-taking initializer should be used.
 ///
 /// # Safety
 ///
 /// This forwards directly to the Objective-C method named in the `method` attribute
 /// below, whose preconditions the compiler cannot check.
 #[unsafe(method(init))]
 #[unsafe(method_family = init)]
 pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

 /// Creates an instance by sending the Objective-C `new` message to the class.
 ///
 /// NOTE(review): as with `init`, no `MTLDevice` is supplied. Confirm whether the
 /// resulting optimizer is usable.
 ///
 /// # Safety
 ///
 /// This forwards directly to the Objective-C method named in the `method` attribute
 /// below, whose preconditions the compiler cannot check.
 #[unsafe(method(new))]
 #[unsafe(method_family = new)]
 pub unsafe fn new() -> Retained<Self>;
 );
}