Trait MPSCNNConvolutionDataSource

pub unsafe trait MPSCNNConvolutionDataSource: NSCopying + NSObjectProtocol {
    // Provided methods
    unsafe fn dataType(&self) -> MPSDataType where Self: Sized + Message { ... }
    unsafe fn descriptor(&self) -> Retained<MPSCNNConvolutionDescriptor> where Self: Sized + Message { ... }
    unsafe fn weights(&self) -> NonNull<c_void> where Self: Sized + Message { ... }
    unsafe fn biasTerms(&self) -> *mut c_float where Self: Sized + Message { ... }
    unsafe fn load(&self) -> bool where Self: Sized + Message { ... }
    unsafe fn purge(&self) where Self: Sized + Message { ... }
    unsafe fn label(&self) -> Option<Retained<NSString>> where Self: Sized + Message { ... }
    unsafe fn lookupTableForUInt8Kernel(&self) -> NonNull<c_float> where Self: Sized + Message { ... }
    unsafe fn weightsQuantizationType(&self) -> MPSCNNWeightsQuantizationType where Self: Sized + Message { ... }
    unsafe fn updateWithCommandBuffer_gradientState_sourceState(
        &self,
        command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
        gradient_state: &MPSCNNConvolutionGradientState,
        source_state: &MPSCNNConvolutionWeightsAndBiasesState,
    ) -> Option<Retained<MPSCNNConvolutionWeightsAndBiasesState>> where Self: Sized + Message { ... }
    unsafe fn updateWithGradientState_sourceState(
        &self,
        gradient_state: &MPSCNNConvolutionGradientState,
        source_state: &MPSCNNConvolutionWeightsAndBiasesState,
    ) -> bool where Self: Sized + Message { ... }
    unsafe fn copyWithZone_device(
        &self,
        zone: *mut NSZone,
        device: Option<&ProtocolObject<dyn MTLDevice>>,
    ) -> Retained<Self> where Self: Sized + Message { ... }
    unsafe fn weightsLayout(&self) -> MPSCNNConvolutionWeightsLayout where Self: Sized + Message { ... }
    unsafe fn kernelWeightsDataType(&self) -> MPSDataType where Self: Sized + Message { ... }
}
Available on crate feature MPSCNNConvolution only.

Provides convolution filter weights and bias terms

The MPSCNNConvolutionDataSource protocol declares the methods that an instance of MPSCNNConvolution uses to obtain the weights and bias terms for the CNN convolution filter.

Why? CNN weights can be large. If multiple copies of all the weights for all the convolutions are unpacked in memory at the same time, some devices can run out of memory. An MPSCNNConvolutionDataSource encapsulates a reference to the weights, such as a file path, so that unpacking can be deferred until the data is needed and the data can be purged soon afterward; as a result, not all of the data has to be in memory at once. MPS does not provide a class that conforms to this protocol. You must write your own to encapsulate your data.

Batch normalization and the neuron activation function are handled using the -descriptor method.

Thread safety: The MPSCNNConvolutionDataSource object may be called from threads other than the main thread. If you create multiple MPSNNGraph objects concurrently on multiple threads and these share MPSCNNConvolutionDataSource objects, the data source objects may be called reentrantly.

See also Apple’s documentation

Provided Methods

unsafe fn dataType(&self) -> MPSDataType
where Self: Sized + Message,

Available on crate features MPSCore and MPSCoreTypes only.

Alerts MPS what sort of weights are provided by the object

For normal convolutions using MPSCNNConvolution, MPSDataTypeUInt8, MPSDataTypeFloat16, and MPSDataTypeFloat32 are supported. MPSCNNBinaryConvolution always assumes weights of type MPSDataTypeUInt32.
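
A minimal sketch of this rule in plain Rust, with no objc2 types involved. `WeightType`, `ConvKind`, and `is_supported_weight_type` are hypothetical local names that only mirror the MPSDataType cases mentioned above; they are not part of the bindings.

```rust
/// Hypothetical stand-in for the MPSDataType values mentioned above.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WeightType {
    UInt8,
    Float16,
    Float32,
    UInt32,
}

/// Hypothetical discriminator for the two convolution kinds discussed above.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConvKind {
    Normal, // MPSCNNConvolution
    Binary, // MPSCNNBinaryConvolution
}

/// Returns true if `dataType` can sensibly report `ty` for the given convolution kind.
pub fn is_supported_weight_type(kind: ConvKind, ty: WeightType) -> bool {
    match kind {
        ConvKind::Normal => matches!(
            ty,
            WeightType::UInt8 | WeightType::Float16 | WeightType::Float32
        ),
        // Binary convolutions always assume packed UInt32 weights,
        // so only that type describes the data correctly.
        ConvKind::Binary => ty == WeightType::UInt32,
    }
}
```

A data source could run a check like this in its own tests before committing to a value returned from `dataType`.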

unsafe fn descriptor(&self) -> Retained<MPSCNNConvolutionDescriptor>
where Self: Sized + Message,

Returns an MPSCNNConvolutionDescriptor as needed

MPS will not modify this object, other than perhaps to retain it. Set the appropriate neuron when creating the convolution descriptor, and for batch normalization use:

              -setBatchNormalizationParametersForInferenceWithMean:variance:gamma:beta:epsilon:

Returns: An MPSCNNConvolutionDescriptor that describes the kernel housed by this object.

unsafe fn weights(&self) -> NonNull<c_void>
where Self: Sized + Message,

Returns a pointer to the weights for the convolution.

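The type of each entry in the array is given by -dataType. For an ungrouped convolution (groups = 1), the number of entries is equal to:

                  inputFeatureChannels * outputFeatureChannels * kernelHeight * kernelWidth

The filter weights are laid out as a 4D tensor (array) weight[ outputChannels ][ kernelHeight ][ kernelWidth ][ inputChannels / groups ], so for grouped convolutions the count is outputFeatureChannels * kernelHeight * kernelWidth * (inputFeatureChannels / groups).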
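
To make the indexing concrete, here is a sketch that computes the entry count and the flat offset of weight[o][y][x][i] for the row-major layout described above. `KernelShape` and its methods are hypothetical helpers, and the sketch assumes inputChannels is divisible by groups.

```rust
/// Hypothetical shape description for a (possibly grouped) convolution kernel.
pub struct KernelShape {
    pub output_channels: usize,
    pub input_channels: usize,
    pub kernel_height: usize,
    pub kernel_width: usize,
    pub groups: usize,
}

impl KernelShape {
    /// Input channels seen by each output channel: inputChannels / groups.
    pub fn inputs_per_group(&self) -> usize {
        self.input_channels / self.groups
    }

    /// Total entries in the weight[O][kH][kW][I / groups] tensor.
    pub fn weight_count(&self) -> usize {
        self.output_channels * self.kernel_height * self.kernel_width * self.inputs_per_group()
    }

    /// Flat offset of weight[o][y][x][i]; the innermost (fastest-varying)
    /// dimension is the input channel, as in the OHWI layout.
    pub fn index(&self, o: usize, y: usize, x: usize, i: usize) -> usize {
        ((o * self.kernel_height + y) * self.kernel_width + x) * self.inputs_per_group() + i
    }
}
```

For an 8-output, 4-input, 3x3 kernel, the last element weight[7][2][2][3] lands at offset 287, one less than the 288-entry count.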

Frequently, this function is a single line of code to return a pointer to memory allocated in -load.

Batch normalization parameters are set using -descriptor.

Note: For binary convolutions, the layout of the weights is weight[ outputChannels ][ kernelHeight ][ kernelWidth ][ floor(((inputChannels / groups) + 31) / 32) ], where each 32-bit word packs 32 input feature channels, one bit per channel, in machine byte order. For example, the bit for feature channel 13 can be extracted using bitmask = (1U << 13).
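
A sketch of the packed binary layout, assuming (as the note implies but does not spell out) that channel c within a group lives in word c / 32 at bit c % 32. The helper names are hypothetical, and the u32 values are used in native byte order, matching "machine byte order" above.

```rust
/// Words per (o, y, x) position for a binary convolution:
/// floor(((inputChannels / groups) + 31) / 32), i.e. a ceiling division by 32.
pub fn binary_words_per_position(input_channels: usize, groups: usize) -> usize {
    (input_channels / groups + 31) / 32
}

/// Reads the bit for input channel `c` from the packed words of one (o, y, x)
/// position. Assumption: word c / 32 holds the channel, at bit c % 32.
pub fn binary_weight_bit(words: &[u32], c: usize) -> bool {
    let bitmask = 1u32 << (c % 32);
    (words[c / 32] & bitmask) != 0
}

/// Packs one bool per channel into u32 words using the same convention.
pub fn pack_binary_weights(bits: &[bool]) -> Vec<u32> {
    let mut words = vec![0u32; (bits.len() + 31) / 32];
    for (c, &b) in bits.iter().enumerate() {
        if b {
            words[c / 32] |= 1u32 << (c % 32);
        }
    }
    words
}
```

With 40 channels per group, two words are needed; channel 13 sets bit 13 of the first word and channel 35 sets bit 3 of the second.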

unsafe fn biasTerms(&self) -> *mut c_float
where Self: Sized + Message,

Returns a pointer to the bias terms for the convolution.

Each entry in the array is a single precision IEEE-754 float and represents one bias. The number of entries is equal to outputFeatureChannels.

Frequently, this function is a single line of code that returns a pointer to memory allocated in -load. It may also return nil (a null pointer in this binding) when there are no bias terms.

Note: bias terms are always float, even when the weights are not.
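
Because the Rust signature returns `*mut c_float` (`c_float` is an alias for `f32`), "return nil" corresponds to a null pointer. A sketch of how a data source might back this accessor, using a hypothetical `BiasStorage` type that holds one f32 per output feature channel:

```rust
use std::ptr;

/// Hypothetical storage a data source might fill in `load`.
pub struct BiasStorage {
    /// One f32 per output feature channel, or None when the layer has no bias.
    bias: Option<Vec<f32>>,
}

impl BiasStorage {
    pub fn new(bias: Option<Vec<f32>>) -> Self {
        BiasStorage { bias }
    }

    /// Mirrors the shape of `biasTerms`: a raw f32 pointer, or null for "no bias".
    /// The pointer stays valid only while `self` is alive and the Vec is not resized.
    pub fn bias_terms(&mut self) -> *mut f32 {
        match self.bias.as_mut() {
            Some(v) => v.as_mut_ptr(),
            None => ptr::null_mut(),
        }
    }
}
```

Keeping the buffer owned by the data source, and freeing it in `purge`, matches the load/purge pattern described below.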

unsafe fn load(&self) -> bool
where Self: Sized + Message,

Alerts the data source that the data will be needed soon

Each load alert will be balanced by a purge later, when MPS no longer needs the data from this object. Load will always be called at least once, after initial construction and after each purge of the object, before anything else is called. Note: load may be called merely to inspect the descriptor. In some circumstances, it may be worthwhile to postpone weight and bias construction until they are actually needed, to avoid touching memory and keep the working set small. The load function is intended to be an opportunity to open files or mark memory no longer purgeable.

Returns: YES (true) on success. If NO (false) is returned, expect MPS object construction to fail.

unsafe fn purge(&self)
where Self: Sized + Message,

Alerts the data source that the data is no longer needed

Each load alert will be balanced by a purge later, when MPS no longer needs the data from this object.
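
Loads and purges are balanced, and (per the thread-safety note) several graphs may share one data source. Given both, a data source can count outstanding loads. The following sketch uses a hypothetical `LazyWeights` type guarded by a `Mutex` so that concurrent `load`/`purge` calls from different threads stay consistent. Counting nested loads is a design choice of this sketch, not something the protocol requires, and the "file read" is simulated.

```rust
use std::sync::Mutex;

/// Hypothetical lazily loaded weight store: `load` materializes the weights on
/// first use, and the matching `purge` frees them once no loads are outstanding.
pub struct LazyWeights {
    state: Mutex<State>,
}

struct State {
    outstanding_loads: usize,
    weights: Option<Vec<f32>>,
    reads_from_disk: usize, // how often the (simulated) expensive unpack ran
}

impl LazyWeights {
    pub fn new() -> Self {
        LazyWeights {
            state: Mutex::new(State { outstanding_loads: 0, weights: None, reads_from_disk: 0 }),
        }
    }

    /// Called before the weights are needed. Returns true on success.
    pub fn load(&self) -> bool {
        let mut s = self.state.lock().unwrap();
        if s.weights.is_none() {
            // Stand-in for opening a file and unpacking weights (hypothetical data).
            s.weights = Some(vec![0.25; 16]);
            s.reads_from_disk += 1;
        }
        s.outstanding_loads += 1;
        true
    }

    /// Balances one `load`; releases memory when no loads remain outstanding.
    pub fn purge(&self) {
        let mut s = self.state.lock().unwrap();
        s.outstanding_loads = s.outstanding_loads.saturating_sub(1);
        if s.outstanding_loads == 0 {
            s.weights = None;
        }
    }

    pub fn is_resident(&self) -> bool {
        self.state.lock().unwrap().weights.is_some()
    }

    pub fn reads_from_disk(&self) -> usize {
        self.state.lock().unwrap().reads_from_disk
    }
}
```

Two overlapping loads trigger only one unpack, and the data stays resident until the second purge.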

unsafe fn label(&self) -> Option<Retained<NSString>>
where Self: Sized + Message,

A label that is transferred to the convolution at init time

Overridden by MPSCNNConvolutionNode.label if that label is non-nil.
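
The precedence amounts to `Option::or`. In this sketch, `effective_label` is a hypothetical helper that uses string slices in place of NSString:

```rust
/// Label the convolution ends up with: the node's label when it is non-nil,
/// otherwise the label supplied by the data source (which may also be absent).
pub fn effective_label<'a>(
    node_label: Option<&'a str>,
    data_source_label: Option<&'a str>,
) -> Option<&'a str> {
    node_label.or(data_source_label)
}
```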

unsafe fn lookupTableForUInt8Kernel(&self) -> NonNull<c_float>
where Self: Sized + Message,

A pointer to a 256-entry lookup table containing the values to use for the weight range [0, 255]
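
Conceptually, each stored weight byte selects one table entry. This CPU-side sketch only illustrates that mapping. It is not how MPS applies the table internally, which this page does not describe, and `dequantize_with_lut` is a hypothetical name.

```rust
/// Expands UInt8 weights to f32 through a 256-entry lookup table:
/// a stored byte k is replaced by lut[k].
pub fn dequantize_with_lut(weights: &[u8], lut: &[f32; 256]) -> Vec<f32> {
    weights.iter().map(|&w| lut[w as usize]).collect()
}
```

With a table where entry k holds k * 0.5 - 64.0, the bytes 0, 128, and 255 map to -64.0, 0.0, and 63.5.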

unsafe fn weightsQuantizationType(&self) -> MPSCNNWeightsQuantizationType
where Self: Sized + Message,

Quantization type of the weights. If it returns MPSCNNWeightsQuantizationTypeLookupTable, the lookupTableForUInt8Kernel method must be implemented. If it returns MPSCNNWeightsQuantizationTypeLinear, the rangesForUInt8Kernel method must be implemented.
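
The requirement can be expressed as a validation step. In this sketch, `QuantizationType` is assumed to mirror the MPSCNNWeightsQuantizationType cases (None, Linear, LookupTable). `Implements` and `missing_requirement` are hypothetical helpers for checking a data source before use.

```rust
/// Hypothetical mirror of the MPSCNNWeightsQuantizationType cases.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum QuantizationType {
    None,
    Linear,
    LookupTable,
}

/// Hypothetical record of which optional methods a data source implements.
pub struct Implements {
    pub lookup_table_for_uint8_kernel: bool,
    pub ranges_for_uint8_kernel: bool,
}

/// Returns the name of a required but missing method, if any.
pub fn missing_requirement(q: QuantizationType, imp: &Implements) -> Option<&'static str> {
    match q {
        QuantizationType::LookupTable if !imp.lookup_table_for_uint8_kernel => {
            Some("lookupTableForUInt8Kernel")
        }
        QuantizationType::Linear if !imp.ranges_for_uint8_kernel => Some("rangesForUInt8Kernel"),
        // Unquantized weights, or the required method is present.
        _ => None,
    }
}
```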

unsafe fn updateWithCommandBuffer_gradientState_sourceState(
    &self,
    command_buffer: &ProtocolObject<dyn MTLCommandBuffer>,
    gradient_state: &MPSCNNConvolutionGradientState,
    source_state: &MPSCNNConvolutionWeightsAndBiasesState,
) -> Option<Retained<MPSCNNConvolutionWeightsAndBiasesState>>
where Self: Sized + Message,

Available on crate features MPSCore and MPSNNGradientState and MPSState only.

Callback for the MPSNNGraph to update the convolution weights on GPU.

It is the responsibility of this method to decrement the read count of both the gradientState and the sourceState before returning. BUG: prior to macOS 10.14 and iOS/tvOS 12.0, the MPSNNGraph incorrectly decrements the read count of the gradientState after this method is called.

Parameter commandBuffer: The command buffer on which to do the update. The MPSNNTrainingStyle of the MPSCNNConvolutionGradientNode controls where the update happens; implement this method for GPU-side updates.

Parameter gradientState: A state object, produced by MPSCNNConvolution and updated by MPSCNNConvolutionGradient, that contains the weight gradients.

Parameter sourceState: A state object containing the convolution weights.

Returns: If nil (None), no update occurs. Otherwise, the result is used to update the weights in the MPSNNGraph.
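
The read-count contract can be modeled without Metal. In this sketch, `MockState` is a hypothetical stand-in for an MPSState read count, and `update_on_gpu` shows only the shape of the obligation: whatever the encoding closure produces, both states are decremented before returning, and `None` means no update. It encodes no real GPU work and ignores the pre-10.14 graph bug noted above.

```rust
use std::cell::Cell;

/// Hypothetical stand-in for a state object with a read count.
pub struct MockState {
    pub read_count: Cell<usize>,
}

impl MockState {
    pub fn new(read_count: usize) -> Self {
        MockState { read_count: Cell::new(read_count) }
    }

    fn decrement_read_count(&self) {
        self.read_count.set(self.read_count.get().saturating_sub(1));
    }
}

/// Runs the (hypothetical) encoding step, then decrements the read counts of
/// both input states before returning its result, whether or not an update was produced.
pub fn update_on_gpu<R>(
    gradient: &MockState,
    source: &MockState,
    encode: impl FnOnce() -> Option<R>,
) -> Option<R> {
    let result = encode();
    gradient.decrement_read_count();
    source.decrement_read_count();
    result
}
```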

unsafe fn updateWithGradientState_sourceState(
    &self,
    gradient_state: &MPSCNNConvolutionGradientState,
    source_state: &MPSCNNConvolutionWeightsAndBiasesState,
) -> bool
where Self: Sized + Message,

Available on crate features MPSCore and MPSNNGradientState and MPSState only.

Callback for the MPSNNGraph to update the convolution weights on the CPU. The MPSNNTrainingStyle of the MPSCNNConvolutionGradientNode controls where the update happens; implement this method for CPU-side updates.

Parameter gradientState: A state object, produced by MPSCNNConvolution and updated by MPSCNNConvolutionGradient, that contains the weight gradients. The MPSNNGraph is responsible for calling [gradientState synchronizeOnCommandBuffer:] so that the application gets correct gradients for the CPU-side update.

Parameter sourceState: A state object containing the convolution weights that were used. reloadWeightsWithDataSource: is called on MPSCNNConvolution and MPSCNNConvolutionGradient right after this method is called. Note that the weights in this state may not match the weights in your data source because of conversion loss. They are the weights actually used, so compute the new weights from them rather than from your own copy, which may be incorrect. Write the new weights to your own copy so that the subsequent reload picks them up.

Returns: true on success, false on failure.
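
A sketch of the CPU-side rule, using plain SGD as an example optimizer (the protocol does not prescribe one). The weights read from sourceState are treated as the source of truth, and the result overwrites the data source's own copy. The slices and `sgd_update_cpu` are hypothetical stand-ins for the state's buffers and your implementation.

```rust
/// Example update rule (plain SGD): new = used - learning_rate * gradient.
/// `used` holds the weights read back from sourceState, not the data source's
/// own copy, since conversion may have altered them; the result overwrites `own_copy`.
pub fn sgd_update_cpu(used: &[f32], gradient: &[f32], learning_rate: f32, own_copy: &mut [f32]) -> bool {
    if used.len() != gradient.len() || used.len() != own_copy.len() {
        return false; // report failure through the bool return
    }
    for ((dst, &w), &g) in own_copy.iter_mut().zip(used).zip(gradient) {
        *dst = w - learning_rate * g;
    }
    true
}
```

Note that the stale value in `own_copy` is never read; it is simply overwritten from the weights actually used.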

unsafe fn copyWithZone_device(
    &self,
    zone: *mut NSZone,
    device: Option<&ProtocolObject<dyn MTLDevice>>,
) -> Retained<Self>
where Self: Sized + Message,

When copyWithZone:device: is called on the convolution, the data source's copyWithZone:device: is called if the data source responds to that selector. If not, copyWithZone: is called if the data source responds to it; otherwise, the data source is simply retained. This lets the application make a separate copy of the data source when the convolution itself is copied, for example when copying a training graph to run on a second GPU, so that weight updates on two different GPUs do not end up stomping on the same data source.

Safety

zone must be a valid pointer or null.
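
The fallback order can be summarized as a small decision function. `Responds`, `CopyAction`, and `copy_action` are hypothetical names; in practice MPS makes this choice itself by checking which selectors the data source responds to.

```rust
/// Hypothetical description of which copy selectors a data source responds to.
pub struct Responds {
    pub copy_with_zone_device: bool,
    pub copy_with_zone: bool,
}

#[derive(Debug, PartialEq, Eq)]
pub enum CopyAction {
    CopyWithZoneDevice, // per-device copy: each GPU gets its own data source
    CopyWithZone,       // generic copy
    Retain,             // shared: both convolutions use the same data source
}

/// The fallback order described above.
pub fn copy_action(r: &Responds) -> CopyAction {
    if r.copy_with_zone_device {
        CopyAction::CopyWithZoneDevice
    } else if r.copy_with_zone {
        CopyAction::CopyWithZone
    } else {
        CopyAction::Retain
    }
}
```

Only the first two outcomes give the copied convolution an independent data source; `Retain` leaves both convolutions updating the same object.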

unsafe fn weightsLayout(&self) -> MPSCNNConvolutionWeightsLayout
where Self: Sized + Message,

Layout of the weights returned by the data source. Currently only the OHWI layout, which is the default, is supported. See MPSCNNConvolutionWeightsLayout.

unsafe fn kernelWeightsDataType(&self) -> MPSDataType
where Self: Sized + Message,

Available on crate features MPSCore and MPSCoreTypes only.

Alerts MPS what weight precision to use in the CNNConvolution kernel

If the precision of the weights returned by dataType does not match the precision returned by kernelWeightsDataType, the weights are converted to the precision specified by kernelWeightsDataType before being passed to the kernel. For MPSCNNConvolution, a dataType of MPSDataTypeUInt8 or MPSDataTypeFloat16 requires a kernelWeightsDataType of MPSDataTypeFloat16, while a dataType of MPSDataTypeFloat32 allows a kernelWeightsDataType of either MPSDataTypeFloat16 or MPSDataTypeFloat32. When kernelWeightsDataType returns MPSDataTypeFloat32, the accumulatorPrecisionOption on the CNNConvolution object must be set to MPSNNConvolutionAccumulatorPrecisionOptionFloat. When kernelWeightsDataType is not implemented, the kernel uses float16 precision. MPSCNNBinaryConvolution always assumes weights of type MPSDataTypeUInt32, and kernelWeightsDataType is unused.
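
The MPSCNNConvolution rules above can be checked with a small function. `Precision` and `kernel_precision` are hypothetical. `None` stands for an unimplemented kernelWeightsDataType, and the returned bool says whether float accumulation (MPSNNConvolutionAccumulatorPrecisionOptionFloat) must be enabled.

```rust
/// Hypothetical mirror of the MPSDataType precisions relevant here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Precision {
    UInt8,
    Float16,
    Float32,
}

/// Validates a (dataType, kernelWeightsDataType) pair for MPSCNNConvolution.
/// Returns the effective kernel precision and whether float accumulation is required.
pub fn kernel_precision(data: Precision, kernel: Option<Precision>) -> Result<(Precision, bool), String> {
    // An unimplemented kernelWeightsDataType means the kernel uses Float16.
    let k = kernel.unwrap_or(Precision::Float16);
    let ok = match data {
        Precision::UInt8 | Precision::Float16 => k == Precision::Float16,
        Precision::Float32 => matches!(k, Precision::Float16 | Precision::Float32),
    };
    if !ok {
        return Err(format!("dataType {:?} cannot use kernelWeightsDataType {:?}", data, k));
    }
    // Float32 kernel weights require MPSNNConvolutionAccumulatorPrecisionOptionFloat.
    Ok((k, k == Precision::Float32))
}
```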

Trait Implementations

impl ProtocolType for dyn MPSCNNConvolutionDataSource

Available on crate feature MPSNeuralNetwork only.

const NAME: &'static str = "MPSCNNConvolutionDataSource"

The name of the Objective-C protocol that this type represents.

fn protocol() -> Option<&'static AnyProtocol>

Get a reference to the Objective-C protocol object that this type represents.

impl<T> ImplementedBy<T> for dyn MPSCNNConvolutionDataSource

Available on crate feature MPSNeuralNetwork only.

Implementations on Foreign Types

impl<T> MPSCNNConvolutionDataSource for ProtocolObject<T>

Available on crate feature MPSNeuralNetwork only.

Implementors