Trait BatchMatmulFamily

Source
/// A family of matmuls working with any precision.
///
/// Implementors pick a concrete [`BatchMatmul`] for each [`MatmulPrecision`] via the
/// generic associated type `Matmul`, and share a single `Config` type across all
/// precisions.
///
/// This trait is not dyn compatible (not object safe).
pub trait BatchMatmulFamily:
    'static
    + Send
    + Sync {
    /// The specific [`BatchMatmul`] implementation associated with this family,
    /// instantiated for precision `MP`.
    ///
    /// Its `Config` is constrained to be this family's `Config`.
    type Matmul<MP: MatmulPrecision>: BatchMatmul<MP, Config = Self::Config>;
    /// The configuration type associated with this matmul family.
    type Config: BatchConfig;

    // Required methods
    /// Constructs the configuration based on the matmul problem, selection, and line sizes.
    ///
    /// # Errors
    ///
    /// Returns a `MatmulSetupError` if the configuration cannot be supported on the
    /// current runtime (the one reached through `client`).
    fn setup<MP: MatmulPrecision, R: Runtime>(
        client: &ComputeClient<R::Server, R::Channel>,
        problem: &MatmulProblem,
        selection: &MatmulSelection,
        line_sizes: &MatmulLineSizes,
    ) -> Result<Self::Config, MatmulSetupError>;
    /// Entry point: launches this batch matmul on `client` with the given cube
    /// dimensions, cube count, input/output runtime arguments, cube-count input
    /// arguments, and configuration.
    ///
    /// # Safety
    ///
    /// Out-of-bounds accesses can happen. The caller is responsible for passing
    /// arguments that are consistent with one another and with `config`.
    /// NOTE(review): presumably no bounds validation is done here (hence the
    /// `_unchecked` suffix) — confirm against the implementors.
    unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>(
        client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>,
        cube_dim: CubeDim,
        cube_count: CubeCount,
        input: InputRuntimeArg<'a, MS, R>,
        output: OutputRuntimeArg<'a, MS, R>,
        cube_count_input: CubeCountInputArgs<'a, R>,
        config: Self::Config,
    );

    // Provided method
    /// Filters out line sizes that are incompatible with this matmul family.
    ///
    /// By default, returns the input unchanged.
    fn filter_line_sizes(
        available_line_sizes: AvailableLineSizes,
    ) -> AvailableLineSizes { ... }
}
Expand description

A family of matmuls working with any precision.

Required Associated Types§

Source

type Matmul<MP: MatmulPrecision>: BatchMatmul<MP, Config = Self::Config>

The specific BatchMatmul implementation associated with this family.

Source

type Config: BatchConfig

The configuration type associated with this matmul family.

Required Methods§

Source

fn setup<MP: MatmulPrecision, R: Runtime>( client: &ComputeClient<R::Server, R::Channel>, problem: &MatmulProblem, selection: &MatmulSelection, line_sizes: &MatmulLineSizes, ) -> Result<Self::Config, MatmulSetupError>

Constructs the configuration based on the matmul problem, selection, and line sizes.

This function may return an error if the configuration cannot be supported on the current runtime.

Source

unsafe fn launch_unchecked<'a, MS: MatmulSpec, R: Runtime>( client: &ComputeClient<<R as Runtime>::Server, <R as Runtime>::Channel>, cube_dim: CubeDim, cube_count: CubeCount, input: InputRuntimeArg<'a, MS, R>, output: OutputRuntimeArg<'a, MS, R>, cube_count_input: CubeCountInputArgs<'a, R>, config: Self::Config, )

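Entry point: launches this batch matmul on `client` using the given cube dimensions, cube count, input and output runtime arguments, cube-count input arguments, and configuration.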

§Safety

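Out-of-bounds accesses can happen. The caller must ensure that the launch arguments are consistent with one another and with `config`.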

Provided Methods§

Source

fn filter_line_sizes( available_line_sizes: AvailableLineSizes, ) -> AvailableLineSizes

Filters out line sizes that are incompatible with this matmul family.

By default, returns the input unchanged.

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors§

Source§

impl<GMM: GlobalMatmulFamily, S: GlobalPartitionMatmul> BatchMatmulFamily for PartitionedBatchMatmulFamily<GMM, S>

Source§

type Matmul<MP: MatmulPrecision> = PartitionedBatchMatmul<MP, <GMM as GlobalMatmulFamily>::Matmul<MP>, S>

Source§

type Config = PartitionedBatchConfig<<GMM as GlobalMatmulFamily>::Config>