objc2_metal_performance_shaders/generated/MPSNeuralNetwork/MPSCNNBatchNormalization.rs
1//! This file has been automatically generated by `objc2`'s `header-translator`.
2//! DO NOT EDIT
3use core::ffi::*;
4use core::ptr::NonNull;
5use objc2::__framework_prelude::*;
6use objc2_foundation::*;
7use objc2_metal::*;
8
9use crate::*;
10
extern_class!(
    /// MPSCNNBatchNormalizationState encapsulates the data necessary
    /// to execute batch normalization.
    ///
    /// MPSCNNBatchNormalizationState cannot initialize the size of its own
    /// underlying resources. Use `[MPSCNNBatchNormalizationStatistics resultStateForSourceImages:]`
    /// or `[MPSCNNBatchNormalizationStatistics temporaryResultStateForCommandBuffer:sourceImages:]`.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpscnnbatchnormalizationstate?language=objc)
    // Objective-C superclass chain, immediate superclass first.
    #[unsafe(super(MPSNNGradientState, MPSState, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    // The binding only exists when all three crate features are enabled.
    #[cfg(all(
        feature = "MPSCore",
        feature = "MPSNNGradientState",
        feature = "MPSState"
    ))]
    pub struct MPSCNNBatchNormalizationState;
);
29
// `MPSCNNBatchNormalizationState` has `NSObject` at the root of its declared superclass
// chain, so it conforms to `NSObjectProtocol`. Gated identically to the class declaration.
#[cfg(all(
    feature = "MPSCore",
    feature = "MPSNNGradientState",
    feature = "MPSState"
))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSCNNBatchNormalizationState {}
);
38
// Methods declared directly on `MPSCNNBatchNormalizationState`.
// NOTE(review): this file is generated (see the "DO NOT EDIT" header); documentation fixes
// made here should also be made upstream or they will be lost on regeneration.
#[cfg(all(
    feature = "MPSCore",
    feature = "MPSNNGradientState",
    feature = "MPSState"
))]
impl MPSCNNBatchNormalizationState {
    extern_methods!(
        #[cfg(all(feature = "MPSCNNKernel", feature = "MPSKernel"))]
        /// The `batchNormalization` property of this state, returned as a
        /// `MPSCNNBatchNormalization` kernel object (never nil in this binding).
        ///
        /// NOTE(review): presumably the batch normalization kernel this state belongs to —
        /// confirm against Apple's documentation.
        #[unsafe(method(batchNormalization))]
        #[unsafe(method_family = none)]
        pub unsafe fn batchNormalization(&self) -> Retained<MPSCNNBatchNormalization>;

        /// Unavailable. Use MPSCNNBatchNormalizationStatistics methods to initialize the state object.
        ///
        /// # Safety
        ///
        /// - `resource` may need to be synchronized.
        /// - `resource` may be unretained, you must ensure it is kept alive while in use.
        #[unsafe(method(initWithResource:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithResource(
            this: Allocated<Self>,
            resource: Option<&ProtocolObject<dyn MTLResource>>,
        ) -> Retained<Self>;

        /// Unavailable. Use MPSCNNBatchNormalizationStatistics methods to create the temporary state object.
        #[unsafe(method(temporaryStateWithCommandBuffer:bufferSize:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer_bufferSize(
            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
            buffer_size: usize,
        ) -> Retained<Self>;

        /// `MPSState`-style factory that creates a temporary state on `cmd_buf` holding a
        /// temporary `MTLTexture` described by `descriptor`.
        ///
        /// NOTE(review): the sibling `temporaryStateWithCommandBuffer:bufferSize:` is marked
        /// unavailable for this class, and this state cannot size its own resources; this
        /// overload is likely unavailable too. Prefer the MPSCNNBatchNormalizationStatistics
        /// methods named in the type documentation, and confirm against the Apple header.
        #[unsafe(method(temporaryStateWithCommandBuffer:textureDescriptor:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer_textureDescriptor(
            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
            descriptor: &MTLTextureDescriptor,
        ) -> Retained<Self>;

        /// Reset any accumulated state data to its initial values.
        #[unsafe(method(reset))]
        #[unsafe(method_family = none)]
        pub unsafe fn reset(&self);

        /// Return an MTLBuffer object with the state's current gamma values.
        #[unsafe(method(gamma))]
        #[unsafe(method_family = none)]
        pub unsafe fn gamma(&self) -> Option<Retained<ProtocolObject<dyn MTLBuffer>>>;

        /// Return an MTLBuffer object with the state's current beta values.
        #[unsafe(method(beta))]
        #[unsafe(method_family = none)]
        pub unsafe fn beta(&self) -> Option<Retained<ProtocolObject<dyn MTLBuffer>>>;

        /// Return an MTLBuffer object with the most recently computed batch mean values.
        #[unsafe(method(mean))]
        #[unsafe(method_family = none)]
        pub unsafe fn mean(&self) -> Option<Retained<ProtocolObject<dyn MTLBuffer>>>;

        /// Return an MTLBuffer object with the most recently computed batch variance values.
        #[unsafe(method(variance))]
        #[unsafe(method_family = none)]
        pub unsafe fn variance(&self) -> Option<Retained<ProtocolObject<dyn MTLBuffer>>>;

        /// Return an MTLBuffer object containing the values of the gradient of the loss function
        /// with respect to the scale factors. If a MPSCNNBatchNormalizationGradient kernel
        /// has not successfully generated these values nil (`None`) will be returned.
        #[unsafe(method(gradientForGamma))]
        #[unsafe(method_family = none)]
        pub unsafe fn gradientForGamma(&self) -> Option<Retained<ProtocolObject<dyn MTLBuffer>>>;

        /// Return an MTLBuffer object containing the values of the gradient of the loss function
        /// with respect to the bias terms. If a MPSCNNBatchNormalizationGradient kernel
        /// has not successfully generated these values nil (`None`) will be returned.
        #[unsafe(method(gradientForBeta))]
        #[unsafe(method_family = none)]
        pub unsafe fn gradientForBeta(&self) -> Option<Retained<ProtocolObject<dyn MTLBuffer>>>;
    );
}
119
/// Methods declared on superclass `MPSState`.
#[cfg(all(
    feature = "MPSCore",
    feature = "MPSNNGradientState",
    feature = "MPSState"
))]
impl MPSCNNBatchNormalizationState {
    extern_methods!(
        /// Create a new autoreleased temporary state object without underlying resource
        ///
        /// Parameter `cmdBuf`: The command buffer with which the temporary resource is associated
        #[unsafe(method(temporaryStateWithCommandBuffer:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer(
            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
        ) -> Retained<Self>;

        /// Inherited `MPSState` initializer `-initWithDevice:bufferSize:`.
        ///
        /// NOTE(review): presumably creates a non-temporary state holding an `MTLBuffer` of
        /// `buffer_size` bytes on `device`. This class documents that it cannot size its own
        /// resources, so confirm whether this initializer is supported on this subclass.
        #[unsafe(method(initWithDevice:bufferSize:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_bufferSize(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            buffer_size: usize,
        ) -> Retained<Self>;

        /// Inherited `MPSState` initializer `-initWithDevice:textureDescriptor:`.
        ///
        /// NOTE(review): presumably creates a non-temporary state holding an `MTLTexture`
        /// described by `descriptor`. This class documents that it cannot size its own
        /// resources, so confirm whether this initializer is supported on this subclass.
        #[unsafe(method(initWithDevice:textureDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_textureDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            descriptor: &MTLTextureDescriptor,
        ) -> Retained<Self>;

        /// Plain `-init`. Returns `None` if the Objective-C initializer returns nil.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Option<Retained<Self>>;

        /// Initialize a non-temporary state to hold a number of textures and buffers
        ///
        /// The allocation of each resource will be deferred until it is needed.
        /// This occurs when -resource or -resourceAtIndex: is called.
        ///
        /// Parameter `resourceList`: The list of resources to create.
        #[unsafe(method(initWithDevice:resourceList:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_resourceList(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            resource_list: &MPSStateResourceList,
        ) -> Retained<Self>;

        /// Initialize a temporary state to hold a number of textures and buffers
        ///
        /// The textures occur first in sequence
        #[unsafe(method(temporaryStateWithCommandBuffer:resourceList:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer_resourceList(
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            resource_list: &MPSStateResourceList,
        ) -> Retained<Self>;

        /// Create a state object with a list of MTLResources
        ///
        /// Because MPS prefers deferred allocation of resources
        /// your application should use -initWithTextures:bufferSizes:bufferCount:
        /// whenever possible. This method is useful for cases when the
        /// MTLResources must be initialized by the CPU.
        ///
        /// # Safety
        ///
        /// - `resources` generic may need to be synchronized.
        /// - `resources` generic may be unretained, you must ensure it is kept alive while in use.
        #[unsafe(method(initWithResources:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithResources(
            this: Allocated<Self>,
            resources: Option<&NSArray<ProtocolObject<dyn MTLResource>>>,
        ) -> Retained<Self>;
    );
}
200
/// Methods declared on superclass `NSObject`.
#[cfg(all(
    feature = "MPSCore",
    feature = "MPSNNGradientState",
    feature = "MPSState"
))]
impl MPSCNNBatchNormalizationState {
    extern_methods!(
        /// Allocates and initializes a new instance via `+new` (method family `new`).
        ///
        /// NOTE(review): this state cannot size its own underlying resources (see the type
        /// docs); prefer the MPSCNNBatchNormalizationStatistics factory methods.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
214
extern_class!(
    /// A state which contains mean and variance terms used to apply a
    /// normalization in a MPSCNNBatchNormalization operation.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpscnnnormalizationmeanandvariancestate?language=objc)
    // Objective-C superclass chain, immediate superclass first.
    #[unsafe(super(MPSState, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    // The binding only exists when both crate features are enabled.
    #[cfg(all(feature = "MPSCore", feature = "MPSState"))]
    pub struct MPSCNNNormalizationMeanAndVarianceState;
);
225
// `MPSCNNNormalizationMeanAndVarianceState` has `NSObject` at the root of its declared
// superclass chain, so it conforms to `NSObjectProtocol`. Gated like the class itself.
#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSCNNNormalizationMeanAndVarianceState {}
);
230
// Methods declared directly on `MPSCNNNormalizationMeanAndVarianceState`.
#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
impl MPSCNNNormalizationMeanAndVarianceState {
    extern_methods!(
        /// A MTLBuffer containing the mean terms.
        ///
        /// This binding treats the buffer as always present (non-optional return).
        #[unsafe(method(mean))]
        #[unsafe(method_family = none)]
        pub unsafe fn mean(&self) -> Retained<ProtocolObject<dyn MTLBuffer>>;

        /// A MTLBuffer containing the variance terms.
        ///
        /// This binding treats the buffer as always present (non-optional return).
        #[unsafe(method(variance))]
        #[unsafe(method_family = none)]
        pub unsafe fn variance(&self) -> Retained<ProtocolObject<dyn MTLBuffer>>;

        /// Initialize a MPSCNNNormalizationMeanAndVarianceState object using values
        /// contained in MTLBuffers.
        ///
        /// Parameter `mean`: The MTLBuffer containing mean terms.
        ///
        /// Parameter `variance`: The MTLBuffer containing variance terms.
        ///
        /// # Safety
        ///
        /// - `mean` may need to be synchronized.
        /// - `mean` may be unretained, you must ensure it is kept alive while in use.
        /// - `mean` contents should be of the correct type.
        /// - `variance` may need to be synchronized.
        /// - `variance` may be unretained, you must ensure it is kept alive while in use.
        /// - `variance` contents should be of the correct type.
        #[unsafe(method(initWithMean:variance:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithMean_variance(
            this: Allocated<Self>,
            mean: &ProtocolObject<dyn MTLBuffer>,
            variance: &ProtocolObject<dyn MTLBuffer>,
        ) -> Retained<Self>;

        /// Create a temporary MPSCNNNormalizationMeanAndVarianceState suitable
        /// for a normalization operation on images containing no more than
        /// the specified number of feature channels.
        ///
        /// Parameter `commandBuffer`: The command buffer on which the temporary state will
        /// be used.
        ///
        /// Parameter `numberOfFeatureChannels`: The number of feature channels used to size the
        /// state.
        #[unsafe(method(temporaryStateWithCommandBuffer:numberOfFeatureChannels:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer_numberOfFeatureChannels(
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            number_of_feature_channels: NSUInteger,
        ) -> Retained<Self>;
    );
}
288
/// Methods declared on superclass `MPSState`.
#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
impl MPSCNNNormalizationMeanAndVarianceState {
    extern_methods!(
        /// Create a MPSState holding a temporary MTLBuffer
        ///
        /// Parameter `cmdBuf`: The command buffer against which the temporary resource is allocated
        ///
        /// Parameter `bufferSize`: The size of the buffer in bytes
        #[unsafe(method(temporaryStateWithCommandBuffer:bufferSize:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer_bufferSize(
            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
            buffer_size: usize,
        ) -> Retained<Self>;

        /// Create a MPSState holding a temporary MTLTexture
        ///
        /// Parameter `cmdBuf`: The command buffer against which the temporary resource is allocated
        ///
        /// Parameter `descriptor`: A descriptor for the new temporary texture
        #[unsafe(method(temporaryStateWithCommandBuffer:textureDescriptor:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer_textureDescriptor(
            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
            descriptor: &MTLTextureDescriptor,
        ) -> Retained<Self>;

        /// Create a new autoreleased temporary state object without underlying resource
        ///
        /// Parameter `cmdBuf`: The command buffer with which the temporary resource is associated
        #[unsafe(method(temporaryStateWithCommandBuffer:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer(
            cmd_buf: &ProtocolObject<dyn MTLCommandBuffer>,
        ) -> Retained<Self>;

        /// Inherited `MPSState` initializer `-initWithDevice:bufferSize:`.
        ///
        /// NOTE(review): presumably the non-temporary counterpart of
        /// `temporaryStateWithCommandBuffer:bufferSize:` (size in bytes), allocating on
        /// `device` — confirm against Apple's `MPSState` documentation.
        #[unsafe(method(initWithDevice:bufferSize:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_bufferSize(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            buffer_size: usize,
        ) -> Retained<Self>;

        /// Inherited `MPSState` initializer `-initWithDevice:textureDescriptor:`.
        ///
        /// NOTE(review): presumably the non-temporary counterpart of
        /// `temporaryStateWithCommandBuffer:textureDescriptor:`, allocating on `device` —
        /// confirm against Apple's `MPSState` documentation.
        #[unsafe(method(initWithDevice:textureDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_textureDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            descriptor: &MTLTextureDescriptor,
        ) -> Retained<Self>;

        /// Create a MPSState with a non-temporary MTLResource
        ///
        /// Parameter `resource`: A MTLBuffer or MTLTexture. May be nil.
        ///
        /// # Safety
        ///
        /// - `resource` may need to be synchronized.
        /// - `resource` may be unretained, you must ensure it is kept alive while in use.
        #[unsafe(method(initWithResource:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithResource(
            this: Allocated<Self>,
            resource: Option<&ProtocolObject<dyn MTLResource>>,
        ) -> Retained<Self>;

        /// Plain `-init`. Returns `None` if the Objective-C initializer returns nil.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Option<Retained<Self>>;

        /// Initialize a non-temporary state to hold a number of textures and buffers
        ///
        /// The allocation of each resource will be deferred until it is needed.
        /// This occurs when -resource or -resourceAtIndex: is called.
        ///
        /// Parameter `resourceList`: The list of resources to create.
        #[unsafe(method(initWithDevice:resourceList:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_resourceList(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            resource_list: &MPSStateResourceList,
        ) -> Retained<Self>;

        /// Initialize a temporary state to hold a number of textures and buffers
        ///
        /// The textures occur first in sequence
        #[unsafe(method(temporaryStateWithCommandBuffer:resourceList:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryStateWithCommandBuffer_resourceList(
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            resource_list: &MPSStateResourceList,
        ) -> Retained<Self>;

        /// Create a state object with a list of MTLResources
        ///
        /// Because MPS prefers deferred allocation of resources
        /// your application should use -initWithTextures:bufferSizes:bufferCount:
        /// whenever possible. This method is useful for cases when the
        /// MTLResources must be initialized by the CPU.
        ///
        /// # Safety
        ///
        /// - `resources` generic may need to be synchronized.
        /// - `resources` generic may be unretained, you must ensure it is kept alive while in use.
        #[unsafe(method(initWithResources:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithResources(
            this: Allocated<Self>,
            resources: Option<&NSArray<ProtocolObject<dyn MTLResource>>>,
        ) -> Retained<Self>;
    );
}
404
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCore", feature = "MPSState"))]
impl MPSCNNNormalizationMeanAndVarianceState {
    extern_methods!(
        /// Allocates and initializes a new instance via `+new` (method family `new`).
        ///
        /// To create a state populated with mean/variance buffers, use
        /// `initWithMean_variance` instead.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
414
// NOTE(review): this file is generated (see the "DO NOT EDIT" header); documentation fixes
// made here should also be made upstream or they will be lost on regeneration.
extern_protocol!(
    /// The MPSCNNBatchNormalizationDataSource protocol declares the methods that an
    /// instance of MPSCNNBatchNormalizationState uses to initialize the
    /// scale factors, bias terms, and batch statistics.
    ///
    /// Methods marked `#[optional]` need not be implemented by conforming objects.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpscnnbatchnormalizationdatasource?language=objc)
    pub unsafe trait MPSCNNBatchNormalizationDataSource:
        NSObjectProtocol + NSCopying
    {
        /// Returns the number of feature channels within images to be normalized
        /// using the supplied parameters.
        #[unsafe(method(numberOfFeatureChannels))]
        #[unsafe(method_family = none)]
        unsafe fn numberOfFeatureChannels(&self) -> NSUInteger;

        /// Returns a pointer to the scale factors for the batch normalization.
        #[unsafe(method(gamma))]
        #[unsafe(method_family = none)]
        unsafe fn gamma(&self) -> *mut c_float;

        /// Returns a pointer to the bias terms for the batch normalization.
        /// If NULL then no bias is to be applied.
        #[unsafe(method(beta))]
        #[unsafe(method_family = none)]
        unsafe fn beta(&self) -> *mut c_float;

        /// Returns a pointer to batch mean values with which to initialize
        /// the state for a subsequent batch normalization.
        #[unsafe(method(mean))]
        #[unsafe(method_family = none)]
        unsafe fn mean(&self) -> *mut c_float;

        /// Returns a pointer to batch variance values with which to initialize
        /// the state for a subsequent batch normalization.
        #[unsafe(method(variance))]
        #[unsafe(method_family = none)]
        unsafe fn variance(&self) -> *mut c_float;

        /// Alerts the data source that the data will be needed soon
        ///
        /// Each load alert will be balanced by a purge later, when MPS
        /// no longer needs the data from this object.
        /// Load will always be called at least once after initial construction
        /// or each purge of the object before anything else is called.
        ///
        /// Returns: Returns YES on success. If NO is returned, expect MPS
        /// object construction to fail.
        #[unsafe(method(load))]
        #[unsafe(method_family = none)]
        unsafe fn load(&self) -> bool;

        /// Alerts the data source that the data is no longer needed
        ///
        /// Each load alert will be balanced by a purge later, when MPS
        /// no longer needs the data from this object.
        #[unsafe(method(purge))]
        #[unsafe(method_family = none)]
        unsafe fn purge(&self);

        /// A label that is transferred to the batch normalization filter at init time
        ///
        /// Overridden by a MPSCNNBatchNormalizationNode.label if it is non-nil.
        #[unsafe(method(label))]
        #[unsafe(method_family = none)]
        unsafe fn label(&self) -> Option<Retained<NSString>>;

        #[cfg(all(
            feature = "MPSCNNNormalizationWeights",
            feature = "MPSCore",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Compute new gamma and beta values using current values and gradients contained within a
        /// MPSCNNBatchNormalizationState. Perform the update using a GPU.
        ///
        /// This operation is expected to also decrement the read count of batchNormalizationState by 1.
        ///
        /// Parameter `commandBuffer`: The command buffer on which to encode the update.
        ///
        /// Parameter `batchNormalizationState`: The MPSCNNBatchNormalizationState object containing the current gamma and
        /// beta values and the gradient values.
        ///
        /// Returns: A MPSCNNNormalizationGammaAndBetaState object containing updated gamma and beta values. If NULL, the MPSNNGraph
        /// batch normalization filter gamma and beta values will remain unmodified.
        #[optional]
        #[unsafe(method(updateGammaAndBetaWithCommandBuffer:batchNormalizationState:))]
        #[unsafe(method_family = none)]
        unsafe fn updateGammaAndBetaWithCommandBuffer_batchNormalizationState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
        ) -> Option<Retained<MPSCNNNormalizationGammaAndBetaState>>;

        #[cfg(all(
            feature = "MPSCore",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Compute new mean and variance values using current batch statistics contained within a
        /// MPSCNNBatchNormalizationState. Perform the update using a GPU.
        ///
        /// This operation is expected to also decrement the read count of batchNormalizationState by 1.
        ///
        /// Parameter `commandBuffer`: The command buffer on which to encode the update.
        ///
        /// Parameter `batchNormalizationState`: The MPSCNNBatchNormalizationState object containing the current batch statistics.
        ///
        /// Returns: A MPSCNNNormalizationMeanAndVarianceState object containing updated mean and variance values. If NULL, the MPSNNGraph
        /// batch normalization filter mean and variance values will remain unmodified.
        #[optional]
        #[unsafe(method(updateMeanAndVarianceWithCommandBuffer:batchNormalizationState:))]
        #[unsafe(method_family = none)]
        unsafe fn updateMeanAndVarianceWithCommandBuffer_batchNormalizationState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
        ) -> Option<Retained<MPSCNNNormalizationMeanAndVarianceState>>;

        #[cfg(all(
            feature = "MPSCore",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Compute new gamma and beta values using current values and gradients contained within a
        /// MPSCNNBatchNormalizationState. Perform the update using the CPU.
        ///
        /// Parameter `batchNormalizationState`: The MPSCNNBatchNormalizationState object containing the current gamma and
        /// beta values and the gradient values.
        ///
        /// Returns: A boolean value indicating if the update was performed.
        #[optional]
        #[unsafe(method(updateGammaAndBetaWithBatchNormalizationState:))]
        #[unsafe(method_family = none)]
        unsafe fn updateGammaAndBetaWithBatchNormalizationState(
            &self,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
        ) -> bool;

        #[cfg(all(
            feature = "MPSCore",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Compute new mean and variance values using current batch statistics contained within a
        /// MPSCNNBatchNormalizationState. Perform the update using the CPU.
        ///
        /// Parameter `batchNormalizationState`: The MPSCNNBatchNormalizationState object containing the current batch statistics.
        ///
        /// Returns: A boolean value indicating if the update was performed.
        #[optional]
        #[unsafe(method(updateMeanAndVarianceWithBatchNormalizationState:))]
        #[unsafe(method_family = none)]
        unsafe fn updateMeanAndVarianceWithBatchNormalizationState(
            &self,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
        ) -> bool;

        /// An optional tiny number to use to maintain numerical stability.
        ///
        /// `output_image = (input_image - mean[c]) * gamma[c] / sqrt(variance[c] + epsilon) + beta[c];`
        ///
        /// Default value if method unavailable: `FLT_MIN`
        #[optional]
        #[unsafe(method(epsilon))]
        #[unsafe(method_family = none)]
        unsafe fn epsilon(&self) -> c_float;

        /// NSSecureCoding compatibility.
        ///
        /// # Safety
        ///
        /// `a_coder` possibly has further requirements.
        #[optional]
        #[unsafe(method(encodeWithCoder:))]
        #[unsafe(method_family = none)]
        unsafe fn encodeWithCoder(&self, a_coder: &NSCoder);

        /// NSSecureCoding compatibility.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[optional]
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;

        /// NSSecureCoding compatibility.
        #[optional]
        #[unsafe(method(supportsSecureCoding))]
        #[unsafe(method_family = none)]
        unsafe fn supportsSecureCoding() -> bool;

        /// Optional copy method to create a copy of the data source for use with a new device.
        ///
        /// Parameter `zone`: The NSZone on which to allocate.
        ///
        /// Parameter `device`: The device where the kernel which uses this data source will be used.
        ///
        /// Returns: A pointer to a copy of this data source.
        ///
        /// # Safety
        ///
        /// `zone` must be a valid pointer or null.
        #[optional]
        #[unsafe(method(copyWithZone:device:))]
        #[unsafe(method_family = copy)]
        unsafe fn copyWithZone_device(
            &self,
            zone: *mut NSZone,
            device: Option<&ProtocolObject<dyn MTLDevice>>,
        ) -> Retained<Self>;
    }
);
643
extern_class!(
    /// Dependencies: This depends on Metal.framework
    ///
    /// MPSCNNBatchNormalization normalizes input images using per-channel
    /// means and variances.
    ///
    /// ```text
    /// for (c = 0; c < numberOfFeatureChannels; ++c)
    /// {
    ///     input_image = in(:,:,c,:);
    ///     output_image = (input_image - mean[c]) * gamma[c] / sqrt(variance[c] + epsilon) + beta[c];
    ///     out(:,:,c,:) = output_image;
    /// }
    /// ```
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpscnnbatchnormalization?language=objc)
    // Objective-C superclass chain, immediate superclass first.
    #[unsafe(super(MPSCNNKernel, MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    // The binding only exists when all three crate features are enabled.
    #[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSCNNBatchNormalization;
);
665
// Protocol conformances of `MPSCNNBatchNormalization`. Each item uses the same feature
// gate as the class declaration.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSCNNBatchNormalization {}
);

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSCNNBatchNormalization {}
);

// Declares that copying a `MPSCNNBatchNormalization` yields the same type (`Result = Self`).
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSCNNBatchNormalization {
    type Result = Self;
}

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSCNNBatchNormalization {}
);

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSCNNBatchNormalization {}
);
690
691#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
692impl MPSCNNBatchNormalization {
693 extern_methods!(
694 /// The number of feature channels in an image to be normalized.
695 #[unsafe(method(numberOfFeatureChannels))]
696 #[unsafe(method_family = none)]
697 pub unsafe fn numberOfFeatureChannels(&self) -> NSUInteger;
698
699 /// The epsilon value used in the batch normalization formula to
700 /// bias the variance when normalizing.
701 #[unsafe(method(epsilon))]
702 #[unsafe(method_family = none)]
703 pub unsafe fn epsilon(&self) -> c_float;
704
705 /// Setter for [`epsilon`][Self::epsilon].
706 #[unsafe(method(setEpsilon:))]
707 #[unsafe(method_family = none)]
708 pub unsafe fn setEpsilon(&self, epsilon: c_float);
709
710 /// The data source the batch normalization was initialized with
711 #[unsafe(method(dataSource))]
712 #[unsafe(method_family = none)]
713 pub unsafe fn dataSource(
714 &self,
715 ) -> Retained<ProtocolObject<dyn MPSCNNBatchNormalizationDataSource>>;
716
717 /// Initializes a batch normalization kernel using a data source.
718 ///
719 /// Parameter `device`: The MTLDevice on which this filter will be used
720 ///
721 /// Parameter `dataSource`: A pointer to a object that conforms to the MPSCNNBatchNormalizationDataSource
722 /// protocol. The data source provides filter weights and bias terms and, optionally,
723 /// image statistics which may be used to perform the normalization.
724 ///
725 ///
726 /// Returns: A valid MPSCNNBatchNormalization object or nil, if failure.
727 #[unsafe(method(initWithDevice:dataSource:))]
728 #[unsafe(method_family = init)]
729 pub unsafe fn initWithDevice_dataSource(
730 this: Allocated<Self>,
731 device: &ProtocolObject<dyn MTLDevice>,
732 data_source: &ProtocolObject<dyn MPSCNNBatchNormalizationDataSource>,
733 ) -> Retained<Self>;
734
735 #[cfg(feature = "MPSCNNNeuron")]
736 /// Initializes a batch normalization kernel using a data source and a neuron descriptor.
737 ///
738 /// Parameter `device`: The MTLDevice on which this filter will be used
739 ///
740 /// Parameter `dataSource`: A pointer to a object that conforms to the MPSCNNBatchNormalizationDataSource
741 /// protocol. The data source provides filter weights and bias terms and, optionally,
742 /// image statistics which may be used to perform the normalization.
743 ///
744 /// Parameter `fusedNeuronDescriptor`: A MPSNNNeuronDescriptor object which specifies a neuron activation function to
745 /// be applied to the result of the batch normalization.
746 ///
747 ///
748 /// Returns: A valid MPSCNNBatchNormalization object or nil, if failure.
749 #[unsafe(method(initWithDevice:dataSource:fusedNeuronDescriptor:))]
750 #[unsafe(method_family = init)]
751 pub unsafe fn initWithDevice_dataSource_fusedNeuronDescriptor(
752 this: Allocated<Self>,
753 device: &ProtocolObject<dyn MTLDevice>,
754 data_source: &ProtocolObject<dyn MPSCNNBatchNormalizationDataSource>,
755 fused_neuron_descriptor: Option<&MPSNNNeuronDescriptor>,
756 ) -> Retained<Self>;
757
758 /// Use initWithDevice:dataSource instead
759 #[unsafe(method(initWithDevice:))]
760 #[unsafe(method_family = init)]
761 pub unsafe fn initWithDevice(
762 this: Allocated<Self>,
763 device: &ProtocolObject<dyn MTLDevice>,
764 ) -> Retained<Self>;
765
766 /// NSSecureCoding compatability
767 ///
768 /// While the standard NSSecureCoding/NSCoding method
769 /// -initWithCoder: should work, since the file can't
770 /// know which device your data is allocated on, we
771 /// have to guess and may guess incorrectly. To avoid
772 /// that problem, use a subclass of NSCoder that
773 /// implements the
774 /// <MPSDeviceProvider
775 /// > protocol to
776 /// tell MPS the MTLDevice to use.
777 ///
778 /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSKernel
779 ///
780 /// Parameter `device`: The MTLDevice on which to make the MPSKernel
781 ///
782 /// Returns: A new MPSCNNBatchNormalization object, or nil if failure.
783 ///
784 /// # Safety
785 ///
786 /// `a_decoder` possibly has further requirements.
787 #[unsafe(method(initWithCoder:device:))]
788 #[unsafe(method_family = init)]
789 pub unsafe fn initWithCoder_device(
790 this: Allocated<Self>,
791 a_decoder: &NSCoder,
792 device: &ProtocolObject<dyn MTLDevice>,
793 ) -> Option<Retained<Self>>;
794
        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode this kernel to a command buffer for a single image using
        /// a batch normalization state.
        ///
        /// Parameter names below use the Objective-C spelling; the Rust arguments
        /// are their snake_case forms (e.g. `commandBuffer` -> `command_buffer`).
        ///
        /// Parameter `commandBuffer`: A valid command buffer to receive the kernel.
        ///
        /// Parameter `sourceImage`: The source MPSImage.
        ///
        /// Parameter `batchNormalizationState`: An MPSCNNBatchNormalizationState containing weights and/or
        /// statistics to use for the batch normalization. If the state
        /// is temporary its read count will be decremented.
        ///
        /// Parameter `destinationImage`: An MPSImage to contain the resulting normalized and scaled
        /// image.
        #[unsafe(method(encodeToCommandBuffer:sourceImage:batchNormalizationState:destinationImage:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceImage_batchNormalizationState_destinationImage(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_image: &MPSImage,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            destination_image: &MPSImage,
        );
823
        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNDArray",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode this kernel to a command buffer for a batch of images using
        /// a batch normalization state.
        ///
        /// Batch counterpart of
        /// `encodeToCommandBuffer_sourceImage_batchNormalizationState_destinationImage`;
        /// sources and destinations are `MPSImageBatch` values.
        ///
        /// Parameter `commandBuffer`: A valid command buffer to receive the kernel.
        ///
        /// Parameter `sourceImages`: The batch of source images.
        ///
        /// Parameter `batchNormalizationState`: An MPSCNNBatchNormalizationState containing weights and/or
        /// statistics to use for the batch normalization. If the state
        /// is temporary its read count will be decremented.
        ///
        /// Parameter `destinationImages`: The batch of images to contain the normalized and scaled
        /// result images.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceImages:batchNormalizationState:destinationImages:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceImages_batchNormalizationState_destinationImages(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_images: &MPSImageBatch,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            destination_images: &MPSImageBatch,
        );
853
        #[cfg(all(feature = "MPSImage", feature = "MPSState"))]
        /// Encodes this kernel for a single image, passing `destination_state` as the
        /// `destinationState:` argument and writing into `destination_image`.
        ///
        /// Binds the Objective-C selector
        /// `encodeToCommandBuffer:sourceImage:destinationState:destinationImage:`.
        ///
        /// NOTE(review): the generated binding carries no header documentation for this
        /// overload; presumably the kernel fills `destination_state` during encoding —
        /// confirm against the MPS documentation.
        #[unsafe(method(encodeToCommandBuffer:sourceImage:destinationState:destinationImage:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceImage_destinationState_destinationImage(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_image: &MPSImage,
            destination_state: &MPSState,
            destination_image: &MPSImage,
        );
864
        #[cfg(all(feature = "MPSImage", feature = "MPSState"))]
        /// Encodes this kernel for a single image and returns the resulting `MPSImage`.
        ///
        /// Binds `encodeToCommandBuffer:sourceImage:destinationState:destinationStateIsTemporary:`.
        ///
        /// - `out_state` is an out-parameter (`&mut Option<Retained<MPSState>>`) passed as
        ///   the `destinationState:` argument; presumably the callee stores the produced
        ///   state into it — confirm.
        /// - `is_temporary` is forwarded as `destinationStateIsTemporary:`.
        #[unsafe(method(encodeToCommandBuffer:sourceImage:destinationState:destinationStateIsTemporary:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceImage_destinationState_destinationStateIsTemporary(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_image: &MPSImage,
            out_state: &mut Option<Retained<MPSState>>,
            is_temporary: bool,
        ) -> Retained<MPSImage>;
875
        #[cfg(all(feature = "MPSImage", feature = "MPSNDArray", feature = "MPSState"))]
        /// Batch encode that writes into `destination_images`.
        ///
        /// Binds `encodeBatchToCommandBuffer:sourceImages:destinationStates:destinationImages:`.
        /// `destination_states` may be `None`, which is passed to Objective-C as nil.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceImages:destinationStates:destinationImages:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceImages_destinationStates_destinationImages(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_images: &MPSImageBatch,
            destination_states: Option<&MPSStateBatch>,
            destination_images: &MPSImageBatch,
        );
886
        #[cfg(all(feature = "MPSImage", feature = "MPSNDArray", feature = "MPSState"))]
        /// Batch encode that returns the resulting `MPSImageBatch`.
        ///
        /// Binds `encodeBatchToCommandBuffer:sourceImages:destinationStates:destinationStateIsTemporary:`.
        ///
        /// - `out_states` is an out-parameter (`&mut Option<Retained<MPSStateBatch>>`) passed
        ///   as `destinationStates:`; presumably the callee stores the produced state batch
        ///   into it — confirm.
        /// - `is_temporary` is forwarded as `destinationStateIsTemporary:`.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceImages:destinationStates:destinationStateIsTemporary:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceImages_destinationStates_destinationStateIsTemporary(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_images: &MPSImageBatch,
            out_states: &mut Option<Retained<MPSStateBatch>>,
            is_temporary: bool,
        ) -> Retained<MPSImageBatch>;
897
        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Return an MPSCNNBatchNormalizationState object which may be used with an MPSCNNBatchNormalization filter.
        ///
        /// `source_states` may be `None` (passed as nil). Returns `None` when the
        /// Objective-C method returns nil.
        #[unsafe(method(resultStateForSourceImage:sourceStates:destinationImage:))]
        #[unsafe(method_family = none)]
        pub unsafe fn resultStateForSourceImage_sourceStates_destinationImage(
            &self,
            source_image: &MPSImage,
            source_states: Option<&NSArray<MPSState>>,
            destination_image: &MPSImage,
        ) -> Option<Retained<MPSCNNBatchNormalizationState>>;
912
        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Return a temporary MPSCNNBatchNormalizationState object which may be used with
        /// an MPSCNNBatchNormalization filter.
        ///
        /// Unlike `resultStateForSourceImage_sourceStates_destinationImage`, this variant
        /// also takes the command buffer the temporary state is associated with.
        /// `source_states` may be `None` (passed as nil). Returns `None` when the
        /// Objective-C method returns nil.
        #[unsafe(method(temporaryResultStateForCommandBuffer:sourceImage:sourceStates:destinationImage:))]
        #[unsafe(method_family = none)]
        pub unsafe fn temporaryResultStateForCommandBuffer_sourceImage_sourceStates_destinationImage(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_image: &MPSImage,
            source_states: Option<&NSArray<MPSState>>,
            destination_image: &MPSImage,
        ) -> Option<Retained<MPSCNNBatchNormalizationState>>;
929
        /// Reinitialize the filter using a data source.
        ///
        /// Deprecated in this binding (see the `#[deprecated]` attribute). The
        /// `reloadGammaAndBeta*` / `reloadMeanAndVariance*` methods reload the two
        /// parameter groups separately.
        ///
        /// Parameter `dataSource` (`data_source`): The data source which will provide the weights and, optionally, the
        /// image batch statistics with which to normalize.
        #[deprecated]
        #[unsafe(method(reloadDataSource:))]
        #[unsafe(method_family = none)]
        pub unsafe fn reloadDataSource(
            &self,
            data_source: &ProtocolObject<dyn MPSCNNBatchNormalizationDataSource>,
        );
942
        /// Reinitialize the filter's gamma and beta values using the data source provided at kernel initialization.
        ///
        /// Takes no arguments; the data source is the one the kernel was created with.
        #[unsafe(method(reloadGammaAndBetaFromDataSource))]
        #[unsafe(method_family = none)]
        pub unsafe fn reloadGammaAndBetaFromDataSource(&self);
947
        /// Reinitialize the filter's mean and variance values using the data source provided at kernel initialization.
        ///
        /// Takes no arguments; the data source is the one the kernel was created with.
        #[unsafe(method(reloadMeanAndVarianceFromDataSource))]
        #[unsafe(method_family = none)]
        pub unsafe fn reloadMeanAndVarianceFromDataSource(&self);
952
        #[cfg(all(feature = "MPSCNNNormalizationWeights", feature = "MPSState"))]
        /// Reload data using new gamma and beta terms contained within an
        /// MPSCNNNormalizationGammaAndBetaState object.
        ///
        /// Unlike `reloadGammaAndBetaFromDataSource`, the reload is encoded on a
        /// command buffer.
        ///
        /// Parameter `commandBuffer` (`command_buffer`): The command buffer on which to encode the reload.
        ///
        /// Parameter `gammaAndBetaState` (`gamma_and_beta_state`): The state containing the updated weights which are to
        /// be reloaded.
        #[unsafe(method(reloadGammaAndBetaWithCommandBuffer:gammaAndBetaState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn reloadGammaAndBetaWithCommandBuffer_gammaAndBetaState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            gamma_and_beta_state: &MPSCNNNormalizationGammaAndBetaState,
        );
970
        #[cfg(feature = "MPSState")]
        /// Reload data using new mean and variance terms contained within an
        /// MPSCNNNormalizationMeanAndVarianceState object.
        ///
        /// Unlike `reloadMeanAndVarianceFromDataSource`, the reload is encoded on a
        /// command buffer.
        ///
        /// NOTE(review): gated only on `MPSState`, unlike the gamma/beta counterpart,
        /// which also needs `MPSCNNNormalizationWeights`. Presumably this is because the
        /// mean/variance state type is declared in this module — confirm.
        ///
        /// Parameter `commandBuffer` (`command_buffer`): The command buffer on which to encode the reload.
        ///
        /// Parameter `meanAndVarianceState` (`mean_and_variance_state`): The state containing the updated statistics which are to
        /// be reloaded.
        #[unsafe(method(reloadMeanAndVarianceWithCommandBuffer:meanAndVarianceState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn reloadMeanAndVarianceWithCommandBuffer_meanAndVarianceState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            mean_and_variance_state: &MPSCNNNormalizationMeanAndVarianceState,
        );
988 );
989}
990
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalization {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels
        ///
        /// This isn't the right interface to decode a MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// Returns `None` when the Objective-C initializer returns nil.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;
    );
}
1015
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalization {
    extern_methods!(
        /// Plain `-init` inherited from `NSObject`.
        ///
        /// NOTE(review): this kernel also declares device-taking initializers
        /// (e.g. `initWithDevice`); whether a device-less `init` yields a usable
        /// kernel is not established here — confirm before relying on it.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// `+new` inherited from `NSObject` (allocate + `init`); same caveat as `init`.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
1029
extern_class!(
    /// Dependencies: This depends on Metal.framework
    ///
    /// MPSCNNBatchNormalizationStatistics updates an MPSCNNBatchNormalizationState
    /// with the batch statistics necessary to perform a batch normalization.
    /// MPSCNNBatchNormalizationStatistics may be executed multiple times with
    /// multiple images to accumulate all the statistics necessary to perform
    /// a batch normalization as described in <https://arxiv.org/pdf/1502.03167v3.pdf>.
    ///
    /// Superclass chain: `MPSCNNKernel` -> `MPSKernel` -> `NSObject`.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpscnnbatchnormalizationstatistics?language=objc)
    #[unsafe(super(MPSCNNKernel, MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    #[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSCNNBatchNormalizationStatistics;
);
1045
// Protocol conformances of `MPSCNNBatchNormalizationStatistics`, each gated on the
// same feature set as the class itself.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSCNNBatchNormalizationStatistics {}
);

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSCNNBatchNormalizationStatistics {}
);

// Copying this class yields an object of the same class (`Result = Self`).
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSCNNBatchNormalizationStatistics {
    type Result = Self;
}

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSCNNBatchNormalizationStatistics {}
);

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSCNNBatchNormalizationStatistics {}
);
1070
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationStatistics {
    extern_methods!(
        /// Initialize this kernel on a device.
        ///
        /// Parameter `device`: The MTLDevice on which to initialize the kernel.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;

        /// NSSecureCoding compatibility.
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// -initWithCoder: should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly. To avoid
        /// that problem, use this initializer, which takes the
        /// device explicitly.
        ///
        /// Parameter `aDecoder` (`a_decoder`): The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSCNNBatchNormalizationStatistics object, or `None` when the
        /// Objective-C initializer returns nil (failure).
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;

        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNDArray",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode this operation to a command buffer.
        ///
        /// Parameter `commandBuffer` (`command_buffer`): The command buffer.
        ///
        /// Parameter `sourceImages` (`source_images`): An MPSImageBatch containing the source images.
        ///
        /// Parameter `batchNormalizationState` (`batch_normalization_state`): A valid MPSCNNBatchNormalizationState object which
        /// will be updated with the image batch statistics.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceImages:batchNormalizationState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceImages_batchNormalizationState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_images: &MPSImageBatch,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
        );

        #[cfg(all(feature = "MPSImage", feature = "MPSNDArray"))]
        /// Binds `encodeBatchToCommandBuffer:sourceImages:destinationImages:`.
        ///
        /// NOTE(review): no header documentation accompanies this overload; unlike
        /// `encodeBatchToCommandBuffer_sourceImages_batchNormalizationState` it takes no
        /// state, so presumably it is the generic kernel entry point re-declared on
        /// this class — confirm.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceImages:destinationImages:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceImages_destinationImages(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_images: &MPSImageBatch,
            destination_images: &MPSImageBatch,
        );

        #[cfg(feature = "MPSImage")]
        /// Binds `encodeToCommandBuffer:sourceImage:destinationImage:` (single image,
        /// caller-provided destination). Same review note as the batch overload above.
        #[unsafe(method(encodeToCommandBuffer:sourceImage:destinationImage:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceImage_destinationImage(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_image: &MPSImage,
            destination_image: &MPSImage,
        );

        #[cfg(feature = "MPSImage")]
        /// Binds `encodeToCommandBuffer:sourceImage:`; returns the `MPSImage` produced
        /// by the call rather than writing into a caller-provided destination.
        #[unsafe(method(encodeToCommandBuffer:sourceImage:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceImage(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_image: &MPSImage,
        ) -> Retained<MPSImage>;

        #[cfg(all(feature = "MPSImage", feature = "MPSNDArray"))]
        /// Binds `encodeBatchToCommandBuffer:sourceImages:`; returns the resulting
        /// `MPSImageBatch`.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceImages:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceImages(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_images: &MPSImageBatch,
        ) -> Retained<MPSImageBatch>;
    );
}
1172
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationStatistics {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels
        ///
        /// This isn't the right interface to decode a MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// Returns `None` when the Objective-C initializer returns nil.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;
    );
}
1197
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationStatistics {
    extern_methods!(
        /// Plain `-init` inherited from `NSObject`.
        ///
        /// NOTE(review): this class also declares `initWithDevice`; whether a
        /// device-less `init` yields a usable kernel is not established here.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// `+new` inherited from `NSObject` (allocate + `init`); same caveat as `init`.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
1211
extern_class!(
    /// Dependencies: This depends on Metal.framework
    ///
    /// MPSCNNBatchNormalizationGradient computes the gradients of a
    /// loss function resulting from a network containing a corresponding
    /// MPSCNNBatchNormalization kernel.
    ///
    /// Two sets of values are computed: the gradient of the loss function
    /// with respect to the batch normalization source images, and the
    /// gradient of the loss function with respect to the scale and bias
    /// terms used to compute the batch normalization.
    ///
    /// Superclass chain: `MPSCNNGradientKernel` -> `MPSCNNBinaryKernel` -> `MPSKernel` -> `NSObject`.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpscnnbatchnormalizationgradient?language=objc)
    #[unsafe(super(MPSCNNGradientKernel, MPSCNNBinaryKernel, MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    #[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSCNNBatchNormalizationGradient;
);
1231
// Protocol conformances of `MPSCNNBatchNormalizationGradient`, each gated on the
// same feature set as the class itself.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSCNNBatchNormalizationGradient {}
);

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSCNNBatchNormalizationGradient {}
);

// Copying this class yields an object of the same class (`Result = Self`).
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSCNNBatchNormalizationGradient {
    type Result = Self;
}

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSCNNBatchNormalizationGradient {}
);

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSCNNBatchNormalizationGradient {}
);
1256
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationGradient {
    extern_methods!(
        #[cfg(feature = "MPSCNNNeuron")]
        /// Initializes a batch normalization gradient kernel using a device and neuron descriptor.
        ///
        /// Parameter `device`: The MTLDevice on which this filter will be used
        ///
        /// Parameter `fusedNeuronDescriptor` (`fused_neuron_descriptor`, may be `None`):
        /// A MPSNNNeuronDescriptor object which specifies a neuron activation function whose
        /// gradient should be applied prior to computing the resulting gradient.
        /// This neuron descriptor should match that used in the corresponding forward batch
        /// normalization kernel as well as the preceding batch normalization statistics gradient
        /// kernel.
        ///
        /// Returns: A valid MPSCNNBatchNormalizationGradient object or nil, if failure.
        ///
        /// NOTE(review): this binding returns the non-optional `Retained<Self>`, so the nil
        /// case above is not expressed in the Rust signature — confirm the nullability of
        /// the underlying initializer.
        #[unsafe(method(initWithDevice:fusedNeuronDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_fusedNeuronDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            fused_neuron_descriptor: Option<&MPSNNNeuronDescriptor>,
        ) -> Retained<Self>;

        /// NSSecureCoding compatibility.
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// -initWithCoder: should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly. To avoid
        /// that problem, use a subclass of NSCoder that
        /// implements the `MPSDeviceProvider` protocol to
        /// tell MPS the MTLDevice to use.
        ///
        /// Parameter `aDecoder` (`a_decoder`): The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSCNNBatchNormalizationGradient object, or `None` when the
        /// Objective-C initializer returns nil (failure).
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;

        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode this operation to a command buffer for a single image.
        ///
        /// Parameter `commandBuffer` (`command_buffer`): The command buffer.
        ///
        /// Parameter `sourceGradient` (`source_gradient`): An MPSImage containing the gradient of the loss function with
        /// respect to the results of batch normalization on the source image.
        ///
        /// Parameter `sourceImage` (`source_image`): An MPSImage containing the source image for batch normalization.
        ///
        /// Parameter `batchNormalizationState` (`batch_normalization_state`): A valid MPSCNNBatchNormalizationState object which
        /// has been previously updated using a MPSCNNBatchNormalizationStatisticsGradient
        /// kernel and the source images. If the state is temporary its read count will be decremented.
        ///
        /// Parameter `destinationGradient` (`destination_gradient`): An MPSImage which contains the gradient of the loss function with respect to the source image.
        #[unsafe(method(encodeToCommandBuffer:sourceGradient:sourceImage:batchNormalizationState:destinationGradient:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceGradient_sourceImage_batchNormalizationState_destinationGradient(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradient: &MPSImage,
            source_image: &MPSImage,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            destination_gradient: &MPSImage,
        );

        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNDArray",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode this operation to a command buffer.
        ///
        /// Parameter `commandBuffer` (`command_buffer`): The command buffer.
        ///
        /// Parameter `sourceGradients` (`source_gradients`): An MPSImageBatch containing the gradient of the
        /// loss function with respect to the results of batch normalization
        /// on the source images.
        ///
        /// Parameter `sourceImages` (`source_images`): An MPSImageBatch containing the source images for
        /// batch normalization.
        ///
        /// Parameter `batchNormalizationState` (`batch_normalization_state`): A valid MPSCNNBatchNormalizationState object which
        /// has been previously updated using a MPSCNNBatchNormalizationStatisticsGradient
        /// kernel and the source images. If the state is temporary its read count will be decremented.
        ///
        /// Parameter `destinationGradients` (`destination_gradients`): An MPSImageBatch whose images will contain the gradient
        /// of the loss function with respect to the source images.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceGradients:sourceImages:batchNormalizationState:destinationGradients:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceGradients_sourceImages_batchNormalizationState_destinationGradients(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradients: &MPSImageBatch,
            source_images: &MPSImageBatch,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
            destination_gradients: &MPSImageBatch,
        );

        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode this operation to a command buffer. Create an MPSImage to contain
        /// the result and return it.
        ///
        /// See `encodeToCommandBuffer:sourceGradient:sourceImage:batchNormalizationState:destinationGradient:`
        /// (`encodeToCommandBuffer_sourceGradient_sourceImage_batchNormalizationState_destinationGradient`)
        /// for further details.
        #[unsafe(method(encodeToCommandBuffer:sourceGradient:sourceImage:batchNormalizationState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceGradient_sourceImage_batchNormalizationState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradient: &MPSImage,
            source_image: &MPSImage,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
        ) -> Retained<MPSImage>;

        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNDArray",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode this operation to a command buffer. Create an MPSImageBatch to contain
        /// the result and return it.
        ///
        /// See `encodeBatchToCommandBuffer:sourceGradients:sourceImages:batchNormalizationState:destinationGradients:`
        /// (`encodeBatchToCommandBuffer_sourceGradients_sourceImages_batchNormalizationState_destinationGradients`)
        /// for further details.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceGradients:sourceImages:batchNormalizationState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceGradients_sourceImages_batchNormalizationState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradients: &MPSImageBatch,
            source_images: &MPSImageBatch,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
        ) -> Retained<MPSImageBatch>;

        #[cfg(feature = "MPSImage")]
        /// Binds the generic binary-kernel entry point
        /// `encodeToCommandBuffer:primaryImage:secondaryImage:destinationImage:` on this class.
        #[unsafe(method(encodeToCommandBuffer:primaryImage:secondaryImage:destinationImage:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_primaryImage_secondaryImage_destinationImage(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            primary_image: &MPSImage,
            secondary_image: &MPSImage,
            destination_image: &MPSImage,
        );

        #[cfg(all(feature = "MPSImage", feature = "MPSNDArray"))]
        /// Batch form of `encodeToCommandBuffer_primaryImage_secondaryImage_destinationImage`
        /// (`encodeBatchToCommandBuffer:primaryImages:secondaryImages:destinationImages:`).
        #[unsafe(method(encodeBatchToCommandBuffer:primaryImages:secondaryImages:destinationImages:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_primaryImages_secondaryImages_destinationImages(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            primary_images: &MPSImageBatch,
            secondary_images: &MPSImageBatch,
            destination_images: &MPSImageBatch,
        );

        #[cfg(feature = "MPSImage")]
        /// Binds `encodeToCommandBuffer:primaryImage:secondaryImage:`; returns the
        /// `MPSImage` produced by the call.
        #[unsafe(method(encodeToCommandBuffer:primaryImage:secondaryImage:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_primaryImage_secondaryImage(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            primary_image: &MPSImage,
            secondary_image: &MPSImage,
        ) -> Retained<MPSImage>;

        #[cfg(all(feature = "MPSImage", feature = "MPSNDArray"))]
        /// Binds `encodeBatchToCommandBuffer:primaryImages:secondaryImages:`; returns the
        /// `MPSImageBatch` produced by the call.
        #[unsafe(method(encodeBatchToCommandBuffer:primaryImages:secondaryImages:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_primaryImages_secondaryImages(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            // Plural names: these are image *batches*, matching the sibling
            // `encodeBatchToCommandBuffer_primaryImages_secondaryImages_destinationImages`.
            primary_images: &MPSImageBatch,
            secondary_images: &MPSImageBatch,
        ) -> Retained<MPSImageBatch>;
    );
}
1456
/// Methods declared on superclass `MPSCNNGradientKernel`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationGradient {
    extern_methods!(
        /// Standard init with default properties per filter type
        ///
        /// Parameter `device`: The device that the filter will be used on. May not be NULL
        /// (the `&ProtocolObject` reference type already guarantees this in Rust).
        ///
        /// Returns: A pointer to the newly initialized object. This will fail, returning
        /// nil if the device is not supported. Devices must be
        /// MTLFeatureSet_iOS_GPUFamily2_v1 or later.
        ///
        /// NOTE(review): the Rust return type is the non-optional `Retained<Self>`, so the
        /// nil case described above is not expressed in the signature. Confirm how a nil
        /// return is handled before calling this on possibly-unsupported devices.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;
    );
}
1476
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationGradient {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels
        ///
        /// This isn't the right interface to decode a MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used.
        ///
        /// Returns `None` when the Objective-C initializer returns nil.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;
    );
}
1501
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationGradient {
    extern_methods!(
        /// Plain `-init` inherited from `NSObject`.
        ///
        /// NOTE(review): this class also declares device-taking initializers; whether a
        /// device-less `init` yields a usable kernel is not established here.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// `+new` inherited from `NSObject` (allocate + `init`); same caveat as `init`.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}
1515
extern_class!(
    /// Dependencies: This depends on Metal.framework
    ///
    /// MPSCNNBatchNormalizationStatisticsGradient updates an MPSCNNBatchNormalizationState
    /// with the gradient of the loss function with respect to the batch statistics and
    /// batch normalization weights used to perform a batch normalization.
    ///
    /// Superclass chain: `MPSCNNGradientKernel` -> `MPSCNNBinaryKernel` -> `MPSKernel` -> `NSObject`.
    ///
    /// See also [Apple's documentation](https://developer.apple.com/documentation/metalperformanceshaders/mpscnnbatchnormalizationstatisticsgradient?language=objc)
    #[unsafe(super(MPSCNNGradientKernel, MPSCNNBinaryKernel, MPSKernel, NSObject))]
    #[derive(Debug, PartialEq, Eq, Hash)]
    #[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
    pub struct MPSCNNBatchNormalizationStatisticsGradient;
);
1529
// Protocol conformances of `MPSCNNBatchNormalizationStatisticsGradient`, each gated on
// the same feature set as the class itself.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCoding for MPSCNNBatchNormalizationStatisticsGradient {}
);

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSCopying for MPSCNNBatchNormalizationStatisticsGradient {}
);

// Copying this class yields an object of the same class (`Result = Self`).
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
unsafe impl CopyingHelper for MPSCNNBatchNormalizationStatisticsGradient {
    type Result = Self;
}

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSObjectProtocol for MPSCNNBatchNormalizationStatisticsGradient {}
);

#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
extern_conformance!(
    unsafe impl NSSecureCoding for MPSCNNBatchNormalizationStatisticsGradient {}
);
1554
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationStatisticsGradient {
    extern_methods!(
        #[cfg(feature = "MPSCNNNeuron")]
        /// Initializes a batch normalization statistics gradient kernel using a device and neuron descriptor.
        ///
        /// Parameter `device`: The MTLDevice on which this filter will be used
        ///
        /// Parameter `fusedNeuronDescriptor`: A MPSNNNeuronDescriptor object which specifies a neuron activation function whose
        /// gradient should be applied prior to computing the statistics of the input gradient.
        /// This neuron descriptor should match that used in the corresponding forward batch
        /// normalization kernel. The binding accepts `None` for this argument.
        ///
        ///
        /// Returns: A valid MPSCNNBatchNormalizationStatisticsGradient object, or nil on failure.
        ///
        /// NOTE(review): nil is documented as a possible result, but this binding declares a
        /// non-optional `Retained<Self>` return. A nil result therefore cannot surface as
        /// `None` here; objc2 is expected to panic instead. Confirm against the header's
        /// nullability annotations.
        #[unsafe(method(initWithDevice:fusedNeuronDescriptor:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice_fusedNeuronDescriptor(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
            fused_neuron_descriptor: Option<&MPSNNNeuronDescriptor>,
        ) -> Retained<Self>;

        /// NSSecureCoding compatibility
        ///
        /// While the standard NSSecureCoding/NSCoding method
        /// -initWithCoder: should work, since the file can't
        /// know which device your data is allocated on, we
        /// have to guess and may guess incorrectly. To avoid
        /// that problem, use a subclass of NSCoder that
        /// implements the `MPSDeviceProvider` protocol to
        /// tell MPS the MTLDevice to use.
        ///
        /// Parameter `aDecoder`: The NSCoder subclass with your serialized MPSKernel
        ///
        /// Parameter `device`: The MTLDevice on which to make the MPSKernel
        ///
        /// Returns: A new MPSCNNBatchNormalizationStatisticsGradient object, or nil if failure.
        /// A nil result is surfaced as `None`.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:device:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder_device(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Option<Retained<Self>>;

        #[cfg(all(
            feature = "MPSImage",
            feature = "MPSNDArray",
            feature = "MPSNNGradientState",
            feature = "MPSState"
        ))]
        /// Encode this operation to a command buffer.
        ///
        /// Parameter `commandBuffer`: The command buffer.
        ///
        /// Parameter `sourceGradients`: An MPSImageBatch containing the gradient of the
        /// loss function with respect to the results of batch normalization
        /// on the source images.
        ///
        /// Parameter `sourceImages`: An MPSImageBatch containing the source images for
        /// batch normalization.
        ///
        /// Parameter `batchNormalizationState`: A valid MPSCNNBatchNormalizationState object which
        /// has been previously updated using a MPSCNNBatchNormalizationStatistics
        /// kernel and the source images. Upon completion of the
        /// command buffer, will contain the (possibly partially updated)
        /// gradients for the loss function with respect to the scale and
        /// bias parameters used to compute the batch normalization. The
        /// state will be considered to be completely updated when all
        /// MPSImages in the training batch have been processed. If the state
        /// is temporary its read count will be decremented.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceGradients:sourceImages:batchNormalizationState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceGradients_sourceImages_batchNormalizationState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradients: &MPSImageBatch,
            source_images: &MPSImageBatch,
            batch_normalization_state: &MPSCNNBatchNormalizationState,
        );

        #[cfg(all(feature = "MPSImage", feature = "MPSState"))]
        /// Binding for `-encodeToCommandBuffer:sourceGradient:sourceImage:gradientState:`
        /// (single image, generic `MPSState` gradient state). Returns an `MPSImage`.
        ///
        /// NOTE(review): this overload takes an untyped `MPSState` rather than an
        /// `MPSCNNBatchNormalizationState`. Whether this class supports it, and what the
        /// returned image holds (presumably a kernel-allocated destination gradient), is not
        /// established here. Check Apple's header before relying on it.
        #[unsafe(method(encodeToCommandBuffer:sourceGradient:sourceImage:gradientState:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceGradient_sourceImage_gradientState(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradient: &MPSImage,
            source_image: &MPSImage,
            gradient_state: &MPSState,
        ) -> Retained<MPSImage>;

        #[cfg(all(feature = "MPSImage", feature = "MPSState"))]
        /// Binding for
        /// `-encodeToCommandBuffer:sourceGradient:sourceImage:gradientState:destinationGradient:`
        /// (single image, caller-supplied destination image, no return value).
        ///
        /// NOTE(review): takes an untyped `MPSState`. Whether this class supports this
        /// overload is not established here; check Apple's header.
        #[unsafe(method(encodeToCommandBuffer:sourceGradient:sourceImage:gradientState:destinationGradient:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeToCommandBuffer_sourceGradient_sourceImage_gradientState_destinationGradient(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradient: &MPSImage,
            source_image: &MPSImage,
            gradient_state: &MPSState,
            destination_gradient: &MPSImage,
        );

        #[cfg(all(feature = "MPSImage", feature = "MPSNDArray", feature = "MPSState"))]
        /// Binding for `-encodeBatchToCommandBuffer:sourceGradients:sourceImages:gradientStates:`
        /// (batched, one `MPSState` per image via `MPSStateBatch`). Returns an `MPSImageBatch`.
        ///
        /// NOTE(review): what the returned batch holds (presumably kernel-allocated destination
        /// gradients), and whether this class supports this overload, is not established
        /// here. Check Apple's header.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceGradients:sourceImages:gradientStates:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceGradients_sourceImages_gradientStates(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradients: &MPSImageBatch,
            source_images: &MPSImageBatch,
            gradient_states: &MPSStateBatch,
        ) -> Retained<MPSImageBatch>;

        #[cfg(all(feature = "MPSImage", feature = "MPSNDArray", feature = "MPSState"))]
        /// Binding for
        /// `-encodeBatchToCommandBuffer:sourceGradients:sourceImages:gradientStates:destinationGradients:`
        /// (batched, caller-supplied destination batch, no return value).
        ///
        /// NOTE(review): whether this class supports this overload is not established here;
        /// check Apple's header.
        #[unsafe(method(encodeBatchToCommandBuffer:sourceGradients:sourceImages:gradientStates:destinationGradients:))]
        #[unsafe(method_family = none)]
        pub unsafe fn encodeBatchToCommandBuffer_sourceGradients_sourceImages_gradientStates_destinationGradients(
            &self,
            command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
            source_gradients: &MPSImageBatch,
            source_images: &MPSImageBatch,
            gradient_states: &MPSStateBatch,
            destination_gradients: &MPSImageBatch,
        );
    );
}
1690
/// Methods declared on superclass `MPSCNNGradientKernel`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationStatisticsGradient {
    extern_methods!(
        /// Standard init with default properties per filter type
        ///
        /// Parameter `device`: The device that the filter will be used on. May not be NULL
        /// (in this binding it is a reference, so it cannot be null).
        ///
        /// Returns: A pointer to the newly initialized object. This will fail, returning
        /// nil if the device is not supported. Devices must be
        /// MTLFeatureSet_iOS_GPUFamily2_v1 or later.
        ///
        /// NOTE(review): this binding declares a non-optional `Retained<Self>` return. The
        /// documented nil-on-failure case therefore cannot be observed as `None` here; objc2 is
        /// expected to panic instead. Confirm against the header's nullability annotations.
        #[unsafe(method(initWithDevice:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithDevice(
            this: Allocated<Self>,
            device: &ProtocolObject<dyn MTLDevice>,
        ) -> Retained<Self>;
    );
}
1710
/// Methods declared on superclass `MPSKernel`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationStatisticsGradient {
    extern_methods!(
        /// Called by NSCoder to decode MPSKernels
        ///
        /// This isn't the right interface to decode a MPSKernel, but
        /// it is the one that NSCoder uses. To enable your NSCoder
        /// (e.g. NSKeyedUnarchiver) to set which device to use
        /// extend the object to adopt the MPSDeviceProvider
        /// protocol. Otherwise, the Metal system default device
        /// will be used. When the target device is known, prefer
        /// `initWithCoder_device`, which takes the device explicitly.
        ///
        /// A nil result from `-initWithCoder:` is surfaced as `None`.
        ///
        /// # Safety
        ///
        /// `a_decoder` possibly has further requirements.
        #[unsafe(method(initWithCoder:))]
        #[unsafe(method_family = init)]
        pub unsafe fn initWithCoder(
            this: Allocated<Self>,
            a_decoder: &NSCoder,
        ) -> Option<Retained<Self>>;
    );
}
1735
/// Methods declared on superclass `NSObject`.
#[cfg(all(feature = "MPSCNNKernel", feature = "MPSCore", feature = "MPSKernel"))]
impl MPSCNNBatchNormalizationStatisticsGradient {
    extern_methods!(
        /// `-init`, inherited from `NSObject`. Unlike the other initializers on this
        /// class, it takes no `MTLDevice`.
        ///
        /// NOTE(review): the device-taking initializers are the documented way to create this
        /// kernel. Whether a bare `init` yields a usable object is not established here, so
        /// prefer `initWithDevice` unless this is confirmed.
        #[unsafe(method(init))]
        #[unsafe(method_family = init)]
        pub unsafe fn init(this: Allocated<Self>) -> Retained<Self>;

        /// `+new` (allocate followed by `-init`), inherited from `NSObject`.
        ///
        /// NOTE(review): same caveat as `init`: no `MTLDevice` is supplied.
        #[unsafe(method(new))]
        #[unsafe(method_family = new)]
        pub unsafe fn new() -> Retained<Self>;
    );
}