pub unsafe trait MPSCNNConvolutionDataSource: NSCopying + NSObjectProtocol {
// Provided methods
unsafe fn dataType(&self) -> MPSDataType
where Self: Sized + Message { ... }
unsafe fn descriptor(&self) -> Retained<MPSCNNConvolutionDescriptor>
where Self: Sized + Message { ... }
unsafe fn weights(&self) -> NonNull<c_void>
where Self: Sized + Message { ... }
unsafe fn biasTerms(&self) -> *mut c_float
where Self: Sized + Message { ... }
unsafe fn load(&self) -> bool
where Self: Sized + Message { ... }
unsafe fn purge(&self)
where Self: Sized + Message { ... }
unsafe fn label(&self) -> Option<Retained<NSString>>
where Self: Sized + Message { ... }
unsafe fn lookupTableForUInt8Kernel(&self) -> NonNull<c_float>
where Self: Sized + Message { ... }
unsafe fn weightsQuantizationType(&self) -> MPSCNNWeightsQuantizationType
where Self: Sized + Message { ... }
unsafe fn updateWithCommandBuffer_gradientState_sourceState(
&self,
command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
gradient_state: &MPSCNNConvolutionGradientState,
source_state: &MPSCNNConvolutionWeightsAndBiasesState,
) -> Option<Retained<MPSCNNConvolutionWeightsAndBiasesState>>
where Self: Sized + Message { ... }
unsafe fn updateWithGradientState_sourceState(
&self,
gradient_state: &MPSCNNConvolutionGradientState,
source_state: &MPSCNNConvolutionWeightsAndBiasesState,
) -> bool
where Self: Sized + Message { ... }
unsafe fn copyWithZone_device(
&self,
zone: *mut NSZone,
device: Option<&ProtocolObject<dyn MTLDevice>>,
) -> Retained<Self>
where Self: Sized + Message { ... }
unsafe fn weightsLayout(&self) -> MPSCNNConvolutionWeightsLayout
where Self: Sized + Message { ... }
unsafe fn kernelWeightsDataType(&self) -> MPSDataType
where Self: Sized + Message { ... }
}
Available on crate feature MPSCNNConvolution only.
Provides convolution filter weights and bias terms
The MPSCNNConvolutionDataSource protocol declares the methods that an instance of MPSCNNConvolution uses to obtain the weights and bias terms for the CNN convolution filter.
Why? CNN weights can be large. If unpacked copies of all the weights for all the convolutions are in memory at the same time, some devices can run out of memory. The MPSCNNConvolutionDataSource encapsulates a reference to the weights, such as a file path, so that unpacking can be deferred until needed and the data purged soon afterwards; this way, not all of the data must be in memory at the same time. MPS does not provide a class that conforms to this protocol. It is up to developers to write their own class that encapsulates their data.
Batch normalization and the neuron activation function are handled using the -descriptor method.
Thread safety: The MPSCNNConvolutionDataSource object may be called on threads other than the main thread. If you create multiple MPSNNGraph objects concurrently on multiple threads and they share MPSCNNConvolutionDataSources, the data source objects may be called reentrantly.
See also Apple’s documentation
Provided Methods
unsafe fn dataType(&self) -> MPSDataType
Available on crate features MPSCore and MPSCoreTypes only.
Alerts MPS what sort of weights are provided by the object
For regular convolutions using MPSCNNConvolution, MPSDataTypeUInt8, MPSDataTypeFloat16, and MPSDataTypeFloat32 are supported. MPSCNNBinaryConvolution always assumes weights of type MPSDataTypeUInt32.
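As an illustration of the rule above, the following Rust sketch models the weight data types with a hypothetical `WeightType` enum (not the binding's `MPSDataType`) and checks which ones a regular MPSCNNConvolution accepts. The byte sizes follow from the type names.

```rust
/// Hypothetical stand-in for the MPSDataType values named above;
/// this is not the binding's `MPSDataType`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum WeightType {
    UInt8,
    Float16,
    Float32,
    UInt32,
}

impl WeightType {
    /// Size in bytes of one weight entry of this type.
    fn size_in_bytes(self) -> usize {
        match self {
            WeightType::UInt8 => 1,
            WeightType::Float16 => 2,
            WeightType::Float32 | WeightType::UInt32 => 4,
        }
    }
}

/// True if `dataType` may report this type for a regular (non-binary)
/// MPSCNNConvolution. UInt32 is what MPSCNNBinaryConvolution assumes instead.
fn supported_for_convolution(t: WeightType) -> bool {
    matches!(t, WeightType::UInt8 | WeightType::Float16 | WeightType::Float32)
}

fn main() {
    let all = [WeightType::UInt8, WeightType::Float16, WeightType::Float32, WeightType::UInt32];
    for t in all {
        println!(
            "{:?}: {} byte(s) per weight, regular convolution: {}",
            t,
            t.size_in_bytes(),
            supported_for_convolution(t)
        );
    }
}
```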
unsafe fn descriptor(&self) -> Retained<MPSCNNConvolutionDescriptor>
Returns an MPSCNNConvolutionDescriptor as needed
MPS will not modify this object other than perhaps to retain it. Set the appropriate neuron when creating the convolution descriptor; for batch normalization, use:
-setBatchNormalizationParametersForInferenceWithMean:variance:gamma:beta:epsilon:
Returns: An MPSCNNConvolutionDescriptor that describes the kernel housed by this object.
unsafe fn weights(&self) -> NonNull<c_void>
Returns a pointer to the weights for the convolution.
The type of each entry in the array is given by -dataType. The number of entries is equal to:
inputFeatureChannels * outputFeatureChannels * kernelHeight * kernelWidth
The filter weights are laid out as a 4D tensor (array): weight[ outputChannels ][ kernelHeight ][ kernelWidth ][ inputChannels / groups ]
(For grouped convolutions, the count above therefore uses inputFeatureChannels / groups in place of inputFeatureChannels.)
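The OHWI layout can be made concrete with a small indexing helper. This is a plain-Rust sketch; `ConvShape` and its field names are hypothetical, and the element count assumes each kernel position stores `inputChannels / groups` channels, as the layout states.

```rust
/// Hypothetical shape description; field names are illustrative only.
#[derive(Clone, Copy, Debug)]
struct ConvShape {
    output_channels: usize,
    kernel_height: usize,
    kernel_width: usize,
    input_channels: usize,
    groups: usize,
}

impl ConvShape {
    /// Input channels seen by each group.
    fn input_channels_per_group(&self) -> usize {
        self.input_channels / self.groups
    }

    /// Number of entries in the buffer returned by `weights`.
    fn weight_count(&self) -> usize {
        self.output_channels * self.kernel_height * self.kernel_width * self.input_channels_per_group()
    }

    /// Flat index of weight[o][ky][kx][i] in the OHWI layout.
    fn index(&self, o: usize, ky: usize, kx: usize, i: usize) -> usize {
        let ipg = self.input_channels_per_group();
        debug_assert!(o < self.output_channels && ky < self.kernel_height);
        debug_assert!(kx < self.kernel_width && i < ipg);
        ((o * self.kernel_height + ky) * self.kernel_width + kx) * ipg + i
    }
}

fn main() {
    let shape = ConvShape { output_channels: 8, kernel_height: 3, kernel_width: 3, input_channels: 4, groups: 1 };
    // A buffer like the one `load` might allocate for 32-bit float weights.
    let weights = vec![0.0f32; shape.weight_count()];
    println!("{} weights, last index {}", weights.len(), shape.index(7, 2, 2, 3));
}
```

The innermost dimension is the input channel, so consecutive input channels of one kernel tap are adjacent in memory.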
Frequently, this function is a single line of code to return a pointer to memory allocated in -load.
Batch normalization parameters are set using -descriptor.
Note: For binary convolutions, the weights are laid out as weight[ outputChannels ][ kernelHeight ][ kernelWidth ][ floor(((inputChannels / groups) + 31) / 32) ], with each group of 32 input feature channels packed into one 32-bit word in machine byte order, so that, for example, the bit for feature channel 13 can be extracted using bitmask = (1U << 13).
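A sketch of packing binary weights for one (outputChannel, kernelY, kernelX) position. It assumes, consistent with the `1U << 13` example, that input channel `c` maps to bit `c % 32` of 32-bit word `c / 32`; the helper names are hypothetical.

```rust
/// Number of 32-bit words per (o, ky, kx) position for binary weights:
/// floor((n + 31) / 32), i.e. n / 32 rounded up.
fn words_per_position(input_channels_per_group: usize) -> usize {
    (input_channels_per_group + 31) / 32
}

/// Packs one position's binary weights (true = bit set) into u32 words.
/// Assumption: channel c goes to bit c % 32 of word c / 32.
fn pack_bits(channels: &[bool]) -> Vec<u32> {
    let mut words = vec![0u32; words_per_position(channels.len())];
    for (c, &bit) in channels.iter().enumerate() {
        if bit {
            words[c / 32] |= 1u32 << (c % 32);
        }
    }
    words
}

/// Reads back the bit for channel c; channel 13 uses the mask 1 << 13.
fn channel_bit(words: &[u32], c: usize) -> bool {
    (words[c / 32] & (1u32 << (c % 32))) != 0
}

fn main() {
    let mut channels = vec![false; 40]; // 40 input channels need 2 words
    channels[13] = true;
    channels[35] = true;
    let words = pack_bits(&channels);
    println!("{:#010x} {:#010x}", words[0], words[1]);
    println!("channel 13 set: {}", channel_bit(&words, 13));
}
```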
unsafe fn biasTerms(&self) -> *mut c_float
Returns a pointer to the bias terms for the convolution.
Each entry in the array is a single precision IEEE-754 float and represents one bias. The number of entries is equal to outputFeatureChannels.
Frequently, this function is a single line of code to return a pointer to memory allocated in -load. It may also simply return nil (a null pointer).
Note: bias terms are always float, even when the weights are not.
unsafe fn load(&self) -> bool
Alerts the data source that the data will be needed soon
Each load alert will be balanced by a purge later, when MPS no longer needs the data from this object. Load will always be called at least once after initial construction, and after each purge of the object, before anything else is called. Note: load may be called merely to inspect the descriptor. In some circumstances it may be worthwhile to postpone constructing the weights and bias terms until they are actually needed, to avoid touching memory and keep the working set small. The load function is intended as an opportunity to open files or mark memory as no longer purgeable.
Returns: YES (true) on success. If NO (false) is returned, expect MPS object construction to fail.
unsafe fn purge(&self)
Alerts the data source that the data is no longer needed
Each load alert will be balanced by a purge later, when MPS no longer needs the data from this object.
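The load/purge contract lends itself to a reference-counted lazy cache. The sketch below is a hypothetical model in plain Rust, not an implementation of the protocol: `decode` stands in for reading and unpacking weights from a file, an empty result is treated as a failed load (an assumption), and a `Mutex` guards the state because data sources may be called on threads other than the main thread.

```rust
use std::sync::Mutex;

/// Hypothetical internal state: outstanding `load` calls and the
/// unpacked weights while at least one load is outstanding.
struct State {
    loads: usize,
    weights: Option<Vec<f32>>,
}

/// Hypothetical lazily loaded weight store modelling load/purge.
struct LazyWeights {
    /// Stand-in for reading and unpacking weights from disk.
    decode: fn() -> Vec<f32>,
    state: Mutex<State>,
}

impl LazyWeights {
    fn new(decode: fn() -> Vec<f32>) -> Self {
        LazyWeights { decode, state: Mutex::new(State { loads: 0, weights: None }) }
    }

    /// Mirrors `load`: unpack on first use; return false on failure.
    fn load(&self) -> bool {
        let mut s = self.state.lock().unwrap();
        if s.weights.is_none() {
            let w = (self.decode)();
            if w.is_empty() {
                return false; // assumption: no data means the load failed
            }
            s.weights = Some(w);
        }
        s.loads += 1;
        true
    }

    /// Mirrors `purge`: drop the unpacked data once every load is balanced.
    fn purge(&self) {
        let mut s = self.state.lock().unwrap();
        s.loads = s.loads.saturating_sub(1);
        if s.loads == 0 {
            s.weights = None;
        }
    }

    fn is_loaded(&self) -> bool {
        self.state.lock().unwrap().weights.is_some()
    }
}

fn main() {
    let source = LazyWeights::new(|| vec![0.5f32; 16]);
    let ok = source.load();
    println!("load ok: {}, loaded: {}", ok, source.is_loaded());
    source.purge();
    println!("loaded after purge: {}", source.is_loaded());
}
```

Counting loads, rather than dropping data on the first purge, keeps the weights alive while several users (for example graphs sharing the data source) still hold an unbalanced load.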
unsafe fn label(&self) -> Option<Retained<NSString>>
A label that is transferred to the convolution at init time
Overridden by an MPSCNNConvolutionNode.label if it is non-nil.
unsafe fn lookupTableForUInt8Kernel(&self) -> NonNull<c_float>
A pointer to a 256-entry lookup table containing the values to use for the weight range [0, 255]
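With a lookup table, each stored byte is simply an index into the 256 floats. A minimal sketch with hypothetical helper names; the evenly spaced table is only an example, since the table may hold any 256 values.

```rust
/// Expands u8 weights through a 256-entry table: value = table[byte].
fn dequantize(table: &[f32; 256], quantized: &[u8]) -> Vec<f32> {
    quantized.iter().map(|&q| table[q as usize]).collect()
}

/// Example table spreading the byte range [0, 255] evenly over [lo, hi].
fn uniform_table(lo: f32, hi: f32) -> [f32; 256] {
    let mut table = [0.0f32; 256];
    for (i, v) in table.iter_mut().enumerate() {
        *v = lo + (hi - lo) * (i as f32) / 255.0;
    }
    table
}

fn main() {
    let table = uniform_table(-1.0, 1.0);
    let weights = dequantize(&table, &[0, 128, 255]);
    println!("{:?}", weights);
}
```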
unsafe fn weightsQuantizationType(&self) -> MPSCNNWeightsQuantizationType
Quantization type of the weights. If it returns MPSCNNWeightsQuantizationTypeLookupTable, the lookupTableForUInt8Kernel method must be implemented. If it returns MPSCNNWeightsQuantizationTypeLinear, the rangesForUInt8Kernel method must be implemented.
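The requirement can be expressed as a consistency check. In this sketch the enum mirrors the three quantization types, and `Implements` is a hypothetical record of which optional methods a data source provides; neither is part of the binding.

```rust
/// Mirrors the three weight quantization types (names shortened).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum QuantizationType {
    NotQuantized,
    Linear,
    LookupTable,
}

/// Hypothetical record of which optional methods a data source implements.
struct Implements {
    lookup_table_for_uint8_kernel: bool,
    ranges_for_uint8_kernel: bool,
}

/// Checks that the method required by the reported quantization type exists.
fn check(q: QuantizationType, imp: &Implements) -> Result<(), &'static str> {
    match q {
        QuantizationType::NotQuantized => Ok(()),
        QuantizationType::LookupTable if !imp.lookup_table_for_uint8_kernel => {
            Err("LookupTable requires lookupTableForUInt8Kernel")
        }
        QuantizationType::Linear if !imp.ranges_for_uint8_kernel => {
            Err("Linear requires rangesForUInt8Kernel")
        }
        _ => Ok(()),
    }
}

fn main() {
    let table_only = Implements { lookup_table_for_uint8_kernel: true, ranges_for_uint8_kernel: false };
    for q in [QuantizationType::NotQuantized, QuantizationType::Linear, QuantizationType::LookupTable] {
        println!("{:?}: {:?}", q, check(q, &table_only));
    }
}
```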
unsafe fn updateWithCommandBuffer_gradientState_sourceState(
&self,
command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
gradient_state: &MPSCNNConvolutionGradientState,
source_state: &MPSCNNConvolutionWeightsAndBiasesState,
) -> Option<Retained<MPSCNNConvolutionWeightsAndBiasesState>>
Available on crate features MPSCore and MPSNNGradientState and MPSState only.
Callback for the MPSNNGraph to update the convolution weights on the GPU.
It is the responsibility of this method to decrement the read count of both the gradientState and the sourceState before returning. BUG: prior to macOS 10.14 and iOS/tvOS 12.0, the MPSNNGraph incorrectly decrements the read count of the gradientState after this method is called.
Parameter commandBuffer: The command buffer on which to do the update. MPSCNNConvolutionGradientNode.MPSNNTrainingStyle controls where the update happens; implement this method for a GPU-side update.
Parameter gradientState: A state object produced by the MPSCNNConvolution and updated by MPSCNNConvolutionGradient, containing weight gradients.
Parameter sourceState: A state object containing the convolution weights.
Returns: If nil (None), no update occurs. If non-nil (Some), the result will be used to update the weights in the MPSNNGraph.
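Because the callback must decrement both read counts on every return path, including the one that returns nil, a drop guard is a natural fit in Rust. The sketch below models states with a hypothetical `FakeState` (a read count plus plain `f32` data) and performs a plain SGD step on the CPU purely for illustration; a real implementation would encode the update on the command buffer and return an MPSCNNConvolutionWeightsAndBiasesState.

```rust
use std::cell::Cell;

/// Hypothetical stand-in for an MPSState: a read count plus some data.
struct FakeState {
    read_count: Cell<usize>,
    data: Vec<f32>,
}

impl FakeState {
    fn new(read_count: usize, data: Vec<f32>) -> Self {
        FakeState { read_count: Cell::new(read_count), data }
    }
}

/// Decrements a state's read count when dropped, so every return path
/// (including early returns) releases the state.
struct ReadGuard<'a>(&'a FakeState);

impl Drop for ReadGuard<'_> {
    fn drop(&mut self) {
        let c = &self.0.read_count;
        c.set(c.get().saturating_sub(1));
    }
}

/// Models the callback: None means "no update", Some holds the new weights.
fn update_weights(gradient: &FakeState, source: &FakeState, learning_rate: f32) -> Option<Vec<f32>> {
    let _gradient_guard = ReadGuard(gradient);
    let _source_guard = ReadGuard(source);
    if gradient.data.is_empty() || gradient.data.len() != source.data.len() {
        return None; // no update; both read counts are still decremented
    }
    // Illustrative SGD step; a real implementation would run on the GPU.
    Some(source.data.iter().zip(&gradient.data).map(|(w, g)| *w - learning_rate * *g).collect())
}

fn main() {
    let gradient = FakeState::new(1, vec![1.0, 2.0]);
    let source = FakeState::new(1, vec![0.5, 0.5]);
    let updated = update_weights(&gradient, &source, 0.25);
    println!("{:?}, read counts {} {}", updated, gradient.read_count.get(), source.read_count.get());
}
```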
unsafe fn updateWithGradientState_sourceState(
&self,
gradient_state: &MPSCNNConvolutionGradientState,
source_state: &MPSCNNConvolutionWeightsAndBiasesState,
) -> bool
Available on crate features MPSCore and MPSNNGradientState and MPSState only.
Callback for the MPSNNGraph to update the convolution weights on the CPU. MPSCNNConvolutionGradientNode.MPSNNTrainingStyle controls where the update happens; implement this method for a CPU-side update.
Parameter gradientState: A state object produced by the MPSCNNConvolution and updated by MPSCNNConvolutionGradient, containing weight gradients. MPSNNGraph is responsible for calling [gradientState synchronizeOnCommandBuffer:] so that the application gets correct gradients for the CPU-side update.
Parameter sourceState: A state object containing the convolution weights used. reloadWeightsWithDataSource on MPSCNNConvolution and MPSCNNConvolutionGradient will be called right after this method is called. Note that the weights returned here may not match the weights in your data source due to conversion loss. These are the weights actually used, and they should be what you use to calculate the new weights; your copy may be incorrect. Write the new weights to your copy and return them out the left hand side.
Returns: true on success (no error), false on failure.
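The note about conversion loss means the new weights should be computed from the sourceState values, not from your own copy. A sketch using plain SGD (an assumed optimizer; the text does not prescribe one), with slices standing in for the state's weight and bias buffers and a hypothetical `MyWeights` as the data source's copy:

```rust
/// Hypothetical copy of the weights and biases held by the data source.
struct MyWeights {
    weights: Vec<f32>,
    biases: Vec<f32>,
}

/// Writes `used - lr * grad` into `copy` for weights and biases.
/// `used_*` are the values the kernel actually used (from sourceState),
/// which may differ from `copy` because of conversion loss.
/// Returns false on a size mismatch, mirroring the boolean result.
fn cpu_update(
    copy: &mut MyWeights,
    used_weights: &[f32],
    weight_grads: &[f32],
    used_biases: &[f32],
    bias_grads: &[f32],
    lr: f32,
) -> bool {
    let sizes_ok = copy.weights.len() == used_weights.len()
        && used_weights.len() == weight_grads.len()
        && copy.biases.len() == used_biases.len()
        && used_biases.len() == bias_grads.len();
    if !sizes_ok {
        return false;
    }
    for ((dst, &w), &g) in copy.weights.iter_mut().zip(used_weights).zip(weight_grads) {
        *dst = w - lr * g;
    }
    for ((dst, &b), &g) in copy.biases.iter_mut().zip(used_biases).zip(bias_grads) {
        *dst = b - lr * g;
    }
    true
}

fn main() {
    // The copy has drifted from what the kernel used; the update ignores it.
    let mut copy = MyWeights { weights: vec![9.0, 9.0], biases: vec![9.0] };
    let ok = cpu_update(&mut copy, &[1.0, 2.0], &[2.0, 4.0], &[0.5], &[1.0], 0.5);
    println!("{} {:?} {:?}", ok, copy.weights, copy.biases);
}
```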
unsafe fn copyWithZone_device(
&self,
zone: *mut NSZone,
device: Option<&ProtocolObject<dyn MTLDevice>>,
) -> Retained<Self>
When copyWithZone:device: is called on a convolution, the data source's copyWithZone:device: will be called if the data source object responds to that selector. If not, copyWithZone: will be called if the data source responds to it. Otherwise, the data source is simply retained. This allows the application to make a separate copy of the data source in the convolution when the convolution itself is copied, for example when copying a training graph to run on a second GPU, so that weight updates on two different GPUs don't end up stomping on the same data source.
Safety
zone must be a valid pointer or null.
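The reason for copyWithZone:device: is easiest to see by contrasting a plain retain, which shares the weights, with a per-device deep copy. A hypothetical model in plain Rust, where `Arc<Mutex<Vec<f32>>>` stands in for the data source's weight storage and the device is just a label:

```rust
use std::sync::{Arc, Mutex};

/// Hypothetical data source: a device label plus shared weight storage.
#[derive(Clone)]
struct DataSource {
    device: String,
    weights: Arc<Mutex<Vec<f32>>>,
}

impl DataSource {
    /// Models the fallback of simply retaining the data source:
    /// the "copy" shares the same weight storage.
    fn retained(&self) -> DataSource {
        self.clone()
    }

    /// Models copyWithZone:device:, a deep copy with its own storage
    /// bound to another device.
    fn copy_for_device(&self, device: &str) -> DataSource {
        let snapshot = self.weights.lock().unwrap().clone();
        DataSource {
            device: device.to_string(),
            weights: Arc::new(Mutex::new(snapshot)),
        }
    }
}

fn main() {
    let gpu0 = DataSource { device: "gpu0".to_string(), weights: Arc::new(Mutex::new(vec![1.0, 2.0])) };
    let shared = gpu0.retained();
    let gpu1 = gpu0.copy_for_device("gpu1");

    // An update through the deep copy leaves gpu0's weights untouched...
    gpu1.weights.lock().unwrap()[0] = 5.0;
    // ...while an update through the retained handle is visible to gpu0.
    shared.weights.lock().unwrap()[1] = 7.0;

    println!("{} {:?}", gpu0.device, gpu0.weights.lock().unwrap());
    println!("{} {:?}", gpu1.device, gpu1.weights.lock().unwrap());
}
```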
unsafe fn weightsLayout(&self) -> MPSCNNConvolutionWeightsLayout
Layout of the weights returned by the data source. Currently only the OHWI layout is supported, and it is the default. See MPSCNNConvolutionWeightsLayout.
unsafe fn kernelWeightsDataType(&self) -> MPSDataType
Available on crate features MPSCore and MPSCoreTypes only.
Alerts MPS what weight precision to use in the CNNConvolution kernel
If the precision of the weights returned by dataType does not match the precision returned by kernelWeightsDataType, the weights are converted to the precision specified by kernelWeightsDataType before being passed to the kernel. For MPSCNNConvolution, a dataType of MPSDataTypeUInt8 or MPSDataTypeFloat16 requires a kernelWeightsDataType of MPSDataTypeFloat16. A dataType of MPSDataTypeFloat32 may use a kernelWeightsDataType of MPSDataTypeFloat16 or MPSDataTypeFloat32. When kernelWeightsDataType returns MPSDataTypeFloat32, the accumulatorPrecisionOption on the MPSCNNConvolution object must be set to MPSNNConvolutionAccumulatorPrecisionOptionFloat. When kernelWeightsDataType is not implemented, the kernel uses float16 precision. MPSCNNBinaryConvolution always assumes weights of type MPSDataTypeUInt32, and kernelWeightsDataType is unused.
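The precision rules above can be summarized as a small validation function. This is a sketch with a hypothetical `DType` enum, not the binding's `MPSDataType`; binary convolutions are out of scope because they ignore kernelWeightsDataType.

```rust
/// Hypothetical subset of MPSDataType relevant to MPSCNNConvolution weights.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DType {
    UInt8,
    Float16,
    Float32,
}

/// Effective kernel precision for a (dataType, kernelWeightsDataType) pair.
/// `None` for the kernel type models an unimplemented kernelWeightsDataType.
fn effective_kernel_type(data_type: DType, kernel_type: Option<DType>) -> Result<DType, &'static str> {
    // If kernelWeightsDataType is not implemented, the kernel uses float16.
    let kernel = kernel_type.unwrap_or(DType::Float16);
    match (data_type, kernel) {
        // UInt8 and Float16 weights must run as Float16.
        (DType::UInt8 | DType::Float16, DType::Float16) => Ok(DType::Float16),
        // Float32 weights may run as Float16 or Float32.
        (DType::Float32, DType::Float16 | DType::Float32) => Ok(kernel),
        _ => Err("unsupported dataType / kernelWeightsDataType combination"),
    }
}

/// Float32 kernel weights require the float accumulator
/// (MPSNNConvolutionAccumulatorPrecisionOptionFloat).
fn requires_float_accumulator(kernel: DType) -> bool {
    kernel == DType::Float32
}

fn main() {
    let cases = [
        (DType::UInt8, None),
        (DType::Float32, Some(DType::Float32)),
        (DType::Float16, Some(DType::Float32)),
    ];
    for (data, kernel) in cases {
        match effective_kernel_type(data, kernel) {
            Ok(k) => println!("{:?}/{:?} -> {:?}, float accumulator: {}", data, kernel, k, requires_float_accumulator(k)),
            Err(e) => println!("{:?}/{:?} -> {}", data, kernel, e),
        }
    }
}
```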
Trait Implementations
impl ProtocolType for dyn MPSCNNConvolutionDataSource
Available on crate feature MPSNeuralNetwork only.
impl<T> ImplementedBy<T> for dyn MPSCNNConvolutionDataSource
Available on crate feature MPSNeuralNetwork only.
Implementations on Foreign Types
impl<T> MPSCNNConvolutionDataSource for ProtocolObject<T>
where
T: ?Sized + MPSCNNConvolutionDataSource,
Available on crate feature MPSNeuralNetwork only.