Algorithm

Trait Algorithm 

Source
pub trait Algorithm {
    type TileMatmul: TileMatmulFamily;
    type StageMatmul: StageMatmulFamily;
    type GlobalConvolution: GlobalConvolutionFamily;
    type Args: MatmulArgs;

    // Required methods
    fn num_stages() -> NumStages;
    fn into_tensor_handle<R: Runtime, E: Numeric>(
        client: &ComputeClient<R::Server>,
        handle: &TensorHandleRef<'_, R>,
        ident: MatmulIdent,
    ) -> TensorHandle<R, E>;
    fn selection<R: Runtime>(
        client: &ComputeClient<R::Server>,
        problem: &ConvolutionProblem,
        plane_dim: u32,
        matmul_elems: MatmulElems,
    ) -> Result<MatmulSelection, MatmulSetupError>;

    // Provided methods
    fn cube_count(
        selection: &MatmulSelection,
        problem: &ConvolutionProblem,
    ) -> CubeCount { ... }
    fn multi_row_strategy() -> MultiRowStrategy { ... }
    fn loading_precompute_strategy() -> LoadingPrecomputeStrategy { ... }
    fn reader_mode() -> ReaderMode { ... }
    fn load_specialization() -> LoadSpecializationConfig { ... }
    fn partition_buffering_strategy() -> PartitionBuffering { ... }
    fn setup<R: Runtime, MP: MatmulPrecision>(
        client: &ComputeClient<R::Server>,
        problem: &ConvolutionProblem,
        selection: &MatmulSelection,
        line_sizes: &MatmulLineSizes,
    ) -> Result<GlobalConfig<Self::GlobalConvolution>, MatmulSetupError> { ... }
    fn filter_line_sizes(
        available_line_sizes: AvailableLineSizes,
    ) -> AvailableLineSizes { ... }
}
Expand description

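Specification of a convolution algorithm. It covers the component families the algorithm is composed of (`TileMatmul`, `StageMatmul`, `GlobalConvolution`), its argument type (`Args`), and the methods used to select and set up its configuration.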

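Required Associated Types

type TileMatmul: TileMatmulFamily

type StageMatmul: StageMatmulFamily

type GlobalConvolution: GlobalConvolutionFamily

type Args: MatmulArgs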

Required Methods

Source

fn num_stages() -> NumStages

Source

fn into_tensor_handle<R: Runtime, E: Numeric>( client: &ComputeClient<R::Server>, handle: &TensorHandleRef<'_, R>, ident: MatmulIdent, ) -> TensorHandle<R, E>

Source

fn selection<R: Runtime>( client: &ComputeClient<R::Server>, problem: &ConvolutionProblem, plane_dim: u32, matmul_elems: MatmulElems, ) -> Result<MatmulSelection, MatmulSetupError>

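Provided Methods

fn cube_count( selection: &MatmulSelection, problem: &ConvolutionProblem, ) -> CubeCount

fn multi_row_strategy() -> MultiRowStrategy

fn loading_precompute_strategy() -> LoadingPrecomputeStrategy

fn reader_mode() -> ReaderMode

fn load_specialization() -> LoadSpecializationConfig

fn partition_buffering_strategy() -> PartitionBuffering

fn setup<R: Runtime, MP: MatmulPrecision>( client: &ComputeClient<R::Server>, problem: &ConvolutionProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, ) -> Result<GlobalConfig<Self::GlobalConvolution>, MatmulSetupError>

fn filter_line_sizes( available_line_sizes: AvailableLineSizes, ) -> AvailableLineSizes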

Dyn Compatibility

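This trait is not dyn compatible. None of its methods take a `self` receiver, and several of them (`into_tensor_handle`, `selection`, `setup`) are generic over type parameters.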

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors

Source

impl<TMM: TileMatmulFamily<LhsTile = Strided, RhsTile = Strided, AccTile = CubeOption<Strided>, OutTile = Strided>> Algorithm for MultiStageTmaConvAlgorithm<TMM>

Source

impl<TMM: TileMatmulFamily<LhsTile = Strided, RhsTile = Strided, AccTile = CubeOption<Strided>, OutTile = Strided>> Algorithm for SimpleConvAlgorithm<TMM>

Source

impl<TMM: TileMatmulFamily<LhsTile = Strided, RhsTile = Strided, AccTile = CubeOption<Strided>, OutTile = Strided>> Algorithm for SimpleTmaConvAlgorithm<TMM>