Trait Convolution

Source
/// A convolution algorithm, parameterized by the matmul precision `MP`.
///
/// Implementors supply the LHS/RHS loaders, the accumulator and its loader,
/// and the output unloader that [`Convolution::execute`] drives. Every
/// implementor must be `'static + Send + Sync`.
///
/// This trait is not dyn compatible, so it cannot be used as
/// `dyn Convolution<MP>`.
pub trait Convolution<MP: MatmulPrecision>:
    'static
    + Send
    + Sync {
    /// Loader for the LHS input; produced by `init_lhs_loader` and consumed by `execute`.
    type LhsLoader: CubeType;
    /// Loader for the RHS input; produced by `init_rhs_loader` and consumed by `execute`.
    type RhsLoader: CubeType;
    /// Configuration passed by value to every method except `init_unloader`.
    type Config: ConvGemmConfig;
    /// Accumulator loader built from the optional bias tensor by
    /// `init_bias_loader`; consumed by `execute`.
    type AccumulatorLoader: AccumulatorLoader<MP>;
    /// Output unloader over the `MP::EO` output tensor; produced by
    /// `init_unloader` and used by `execute` to store results.
    type Out: OutputLoader<MP::EO>;
    /// Accumulator state; created by `init_accumulator` and passed mutably
    /// (`&mut`) to `execute`.
    type Accumulator: CubeType;

    // Required methods
    /// Performs the convolution over data loaded by the LHS and RHS loaders,
    /// over the range given for K, and stores the result using the output
    /// unloader.
    ///
    /// To compute the whole range of k values, use `k_range = (0, K)`, where
    /// `K` is the K dimension of LHS and RHS.
    fn execute(
        lhs_loader: Self::LhsLoader,
        rhs_loader: Self::RhsLoader,
        acc_loader: Self::AccumulatorLoader,
        unloader: Self::Out,
        acc: &mut Self::Accumulator,
        k_range: (u32, u32),
        config: Self::Config,
    );
    /// Creates the LHS loader over `lhs` (element type `MP::EI`).
    ///
    /// NOTE(review): `x_offset`/`y_offset` presumably locate the tile this
    /// loader reads within `lhs`, and `runtime_args` carries runtime problem
    /// parameters — confirm against implementors.
    fn init_lhs_loader(
        lhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        config: Self::Config,
    ) -> Self::LhsLoader;
    /// Creates the RHS loader over `rhs` (element type `MP::EI`).
    ///
    /// NOTE(review): `x_offset`/`y_offset` presumably locate the tile this
    /// loader reads within `rhs`, and `runtime_args` carries runtime problem
    /// parameters — confirm against implementors.
    fn init_rhs_loader(
        rhs: VirtualTensor<MP::EI>,
        x_offset: u32,
        y_offset: u32,
        runtime_args: &RuntimeArgs,
        config: Self::Config,
    ) -> Self::RhsLoader;
    /// Creates the accumulator loader from the optional `bias` tensor
    /// (a `CubeOption` over an `MP::EO` tensor).
    ///
    /// NOTE(review): `n_offset` presumably is the offset along the N
    /// dimension — confirm.
    fn init_bias_loader(
        bias: CubeOption<VirtualTensor<MP::EO>>,
        n_offset: u32,
        config: Self::Config,
    ) -> Self::AccumulatorLoader;
    /// Creates the output unloader over the read-write `out` tensor
    /// (element type `MP::EO`). Unlike the other constructors, this one
    /// takes no `config`.
    ///
    /// NOTE(review): `x_offset`/`y_offset` presumably locate the output tile
    /// within `out` — confirm.
    fn init_unloader(
        out: VirtualTensor<MP::EO, ReadWrite>,
        x_offset: u32,
        y_offset: u32,
    ) -> Self::Out;
    /// Creates an accumulator for the given `config`.
    fn init_accumulator(config: Self::Config) -> Self::Accumulator;
    /// Expansion-time counterpart of `execute`: takes the `Scope` being built
    /// plus the `ExpandType` of each argument. `acc` is passed as its expand
    /// type rather than `&mut`, and `config` is passed un-expanded.
    ///
    /// NOTE(review): the `__expand_*` methods look macro-generated; keep each
    /// one in sync with its runtime counterpart — confirm.
    fn __expand_execute(
        context: &mut Scope,
        lhs_loader: <Self::LhsLoader as CubeType>::ExpandType,
        rhs_loader: <Self::RhsLoader as CubeType>::ExpandType,
        acc_loader: <Self::AccumulatorLoader as CubeType>::ExpandType,
        unloader: <Self::Out as CubeType>::ExpandType,
        acc: <Self::Accumulator as CubeType>::ExpandType,
        k_range: <(u32, u32) as CubeType>::ExpandType,
        config: Self::Config,
    ) -> <() as CubeType>::ExpandType;
    /// Expansion-time counterpart of `init_lhs_loader`; `config` stays un-expanded.
    fn __expand_init_lhs_loader(
        context: &mut Scope,
        lhs: <VirtualTensor<MP::EI> as CubeType>::ExpandType,
        x_offset: <u32 as CubeType>::ExpandType,
        y_offset: <u32 as CubeType>::ExpandType,
        runtime_args: <RuntimeArgs as CubeType>::ExpandType,
        config: Self::Config,
    ) -> <Self::LhsLoader as CubeType>::ExpandType;
    /// Expansion-time counterpart of `init_rhs_loader`; `config` stays un-expanded.
    fn __expand_init_rhs_loader(
        context: &mut Scope,
        rhs: <VirtualTensor<MP::EI> as CubeType>::ExpandType,
        x_offset: <u32 as CubeType>::ExpandType,
        y_offset: <u32 as CubeType>::ExpandType,
        runtime_args: <RuntimeArgs as CubeType>::ExpandType,
        config: Self::Config,
    ) -> <Self::RhsLoader as CubeType>::ExpandType;
    /// Expansion-time counterpart of `init_bias_loader`; `config` stays un-expanded.
    fn __expand_init_bias_loader(
        context: &mut Scope,
        bias: <CubeOption<VirtualTensor<MP::EO>> as CubeType>::ExpandType,
        n_offset: <u32 as CubeType>::ExpandType,
        config: Self::Config,
    ) -> <Self::AccumulatorLoader as CubeType>::ExpandType;
    /// Expansion-time counterpart of `init_unloader`.
    fn __expand_init_unloader(
        context: &mut Scope,
        out: <VirtualTensor<MP::EO, ReadWrite> as CubeType>::ExpandType,
        x_offset: <u32 as CubeType>::ExpandType,
        y_offset: <u32 as CubeType>::ExpandType,
    ) -> <Self::Out as CubeType>::ExpandType;
    /// Expansion-time counterpart of `init_accumulator`; `config` stays un-expanded.
    fn __expand_init_accumulator(
        context: &mut Scope,
        config: Self::Config,
    ) -> <Self::Accumulator as CubeType>::ExpandType;
}

Required Associated Types

Required Methods

Source

fn execute( lhs_loader: Self::LhsLoader, rhs_loader: Self::RhsLoader, acc_loader: Self::AccumulatorLoader, unloader: Self::Out, acc: &mut Self::Accumulator, k_range: (u32, u32), config: Self::Config, )

Performs the convolution over data loaded by the LHS and RHS loaders, over the range given for K, and stores the result using the output unloader.

To compute the whole range of k values, use `k_range = (0, K)`, where `K` is the K dimension of LHS and RHS.

Source

fn init_lhs_loader( lhs: VirtualTensor<MP::EI>, x_offset: u32, y_offset: u32, runtime_args: &RuntimeArgs, config: Self::Config, ) -> Self::LhsLoader

Source

fn init_rhs_loader( rhs: VirtualTensor<MP::EI>, x_offset: u32, y_offset: u32, runtime_args: &RuntimeArgs, config: Self::Config, ) -> Self::RhsLoader

Source

fn init_bias_loader( bias: CubeOption<VirtualTensor<MP::EO>>, n_offset: u32, config: Self::Config, ) -> Self::AccumulatorLoader

Source

fn init_unloader( out: VirtualTensor<MP::EO, ReadWrite>, x_offset: u32, y_offset: u32, ) -> Self::Out

Source

fn init_accumulator(config: Self::Config) -> Self::Accumulator

Source

fn __expand_execute( context: &mut Scope, lhs_loader: <Self::LhsLoader as CubeType>::ExpandType, rhs_loader: <Self::RhsLoader as CubeType>::ExpandType, acc_loader: <Self::AccumulatorLoader as CubeType>::ExpandType, unloader: <Self::Out as CubeType>::ExpandType, acc: <Self::Accumulator as CubeType>::ExpandType, k_range: <(u32, u32) as CubeType>::ExpandType, config: Self::Config, ) -> <() as CubeType>::ExpandType

Source

fn __expand_init_lhs_loader( context: &mut Scope, lhs: <VirtualTensor<MP::EI> as CubeType>::ExpandType, x_offset: <u32 as CubeType>::ExpandType, y_offset: <u32 as CubeType>::ExpandType, runtime_args: <RuntimeArgs as CubeType>::ExpandType, config: Self::Config, ) -> <Self::LhsLoader as CubeType>::ExpandType

Source

fn __expand_init_rhs_loader( context: &mut Scope, rhs: <VirtualTensor<MP::EI> as CubeType>::ExpandType, x_offset: <u32 as CubeType>::ExpandType, y_offset: <u32 as CubeType>::ExpandType, runtime_args: <RuntimeArgs as CubeType>::ExpandType, config: Self::Config, ) -> <Self::RhsLoader as CubeType>::ExpandType

Source

fn __expand_init_bias_loader( context: &mut Scope, bias: <CubeOption<VirtualTensor<MP::EO>> as CubeType>::ExpandType, n_offset: <u32 as CubeType>::ExpandType, config: Self::Config, ) -> <Self::AccumulatorLoader as CubeType>::ExpandType

Source

fn __expand_init_unloader( context: &mut Scope, out: <VirtualTensor<MP::EO, ReadWrite> as CubeType>::ExpandType, x_offset: <u32 as CubeType>::ExpandType, y_offset: <u32 as CubeType>::ExpandType, ) -> <Self::Out as CubeType>::ExpandType

Source

fn __expand_init_accumulator( context: &mut Scope, config: Self::Config, ) -> <Self::Accumulator as CubeType>::ExpandType

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors