pub trait MatmulArgs:
Send
+ Sync
+ 'static
+ Clone {
type Input<EI: Numeric>: LaunchArg + CubeType;
type Output<EO: Numeric>: LaunchArg + CubeType;
type State<EI: Numeric, EO: Numeric>: CubeType;
Show 48 methods
// Required methods
fn init_state<EI: Numeric, EO: Numeric>(
input: &Self::Input<EI>,
output: &mut Self::Output<EO>,
) -> Self::State<EI, EO>;
fn read_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
coordinate: u32,
) -> Line<EI>;
fn read_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
coordinate: u32,
) -> Line<EI>;
fn read_window_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
start: u32,
end: u32,
) -> Slice<Line<EI>>;
fn read_window_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
start: u32,
end: u32,
) -> Slice<Line<EI>>;
fn as_tensor_map_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
) -> TensorMap<EI>;
fn as_tensor_map_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
) -> TensorMap<EI>;
fn write_out<EI: Numeric, EO: Numeric>(
state: &mut Self::State<EI, EO>,
coordinate: u32,
value: Line<EO>,
);
fn rank_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
fn rank_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
fn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
fn len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
fn len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
fn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32;
fn buffer_len_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
) -> u32;
fn buffer_len_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
) -> u32;
fn buffer_len_out<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
) -> u32;
fn shape_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32;
fn shape_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32;
fn shape_out<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32;
fn stride_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32;
fn stride_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32;
fn stride_out<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32;
fn quantization<MP: MatmulPrecision>(
state: &Self::State<MP::EI, MP::EO>,
) -> Quantization<MP>;
fn __expand_init_state<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
input: <Self::Input<EI> as CubeType>::ExpandType,
output: <Self::Output<EO> as CubeType>::ExpandType,
) -> <Self::State<EI, EO> as CubeType>::ExpandType;
fn __expand_read_lhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
coordinate: <u32 as CubeType>::ExpandType,
) -> <Line<EI> as CubeType>::ExpandType;
fn __expand_read_rhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
coordinate: <u32 as CubeType>::ExpandType,
) -> <Line<EI> as CubeType>::ExpandType;
fn __expand_read_window_lhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
start: <u32 as CubeType>::ExpandType,
end: <u32 as CubeType>::ExpandType,
) -> <Slice<Line<EI>> as CubeType>::ExpandType;
fn __expand_read_window_rhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
start: <u32 as CubeType>::ExpandType,
end: <u32 as CubeType>::ExpandType,
) -> <Slice<Line<EI>> as CubeType>::ExpandType;
fn __expand_as_tensor_map_lhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <TensorMap<EI> as CubeType>::ExpandType;
fn __expand_as_tensor_map_rhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <TensorMap<EI> as CubeType>::ExpandType;
fn __expand_write_out<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
coordinate: <u32 as CubeType>::ExpandType,
value: <Line<EO> as CubeType>::ExpandType,
) -> <() as CubeType>::ExpandType;
fn __expand_rank_lhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_rank_rhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_rank_out<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_len_lhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_len_rhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_len_out<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_buffer_len_lhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_buffer_len_rhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_buffer_len_out<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_shape_lhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
axis: <u32 as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_shape_rhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
axis: <u32 as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_shape_out<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
axis: <u32 as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_stride_lhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
axis: <u32 as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_stride_rhs<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
axis: <u32 as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_stride_out<EI: Numeric, EO: Numeric>(
scope: &mut Scope,
state: <Self::State<EI, EO> as CubeType>::ExpandType,
axis: <u32 as CubeType>::ExpandType,
) -> <u32 as CubeType>::ExpandType;
fn __expand_quantization<MP: MatmulPrecision>(
scope: &mut Scope,
state: <Self::State<MP::EI, MP::EO> as CubeType>::ExpandType,
) -> <Quantization<MP> as CubeType>::ExpandType;
}
Expand description
Arguments for the matrix multiplication algorithm.
Required Associated Types§
Sourcetype State<EI: Numeric, EO: Numeric>: CubeType
type State<EI: Numeric, EO: Numeric>: CubeType
Inner state that is used to create tensor inputs and tensor outputs.
Required Methods§
Sourcefn init_state<EI: Numeric, EO: Numeric>(
input: &Self::Input<EI>,
output: &mut Self::Output<EO>,
) -> Self::State<EI, EO>
fn init_state<EI: Numeric, EO: Numeric>( input: &Self::Input<EI>, output: &mut Self::Output<EO>, ) -> Self::State<EI, EO>
Initialize the state from the kernel input and output arguments.
Sourcefn read_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
coordinate: u32,
) -> Line<EI>
fn read_lhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, coordinate: u32, ) -> Line<EI>
Read the line of the lhs tensor using the state at the given coordinate.
Sourcefn read_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
coordinate: u32,
) -> Line<EI>
fn read_rhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, coordinate: u32, ) -> Line<EI>
Read the line of the rhs tensor using the state at the given coordinate.
Sourcefn read_window_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
start: u32,
end: u32,
) -> Slice<Line<EI>>
fn read_window_lhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, start: u32, end: u32, ) -> Slice<Line<EI>>
Read a window of lines of the lhs tensor, between the `start` and `end` coordinates, using the state.
Sourcefn read_window_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
start: u32,
end: u32,
) -> Slice<Line<EI>>
fn read_window_rhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, start: u32, end: u32, ) -> Slice<Line<EI>>
Read a window of lines of the rhs tensor, between the `start` and `end` coordinates, using the state.
Sourcefn as_tensor_map_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
) -> TensorMap<EI>
fn as_tensor_map_lhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, ) -> TensorMap<EI>
Reinterpret lhs as a tensor map.
Sourcefn as_tensor_map_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
) -> TensorMap<EI>
fn as_tensor_map_rhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, ) -> TensorMap<EI>
Reinterpret rhs as a tensor map.
Sourcefn write_out<EI: Numeric, EO: Numeric>(
state: &mut Self::State<EI, EO>,
coordinate: u32,
value: Line<EO>,
)
fn write_out<EI: Numeric, EO: Numeric>( state: &mut Self::State<EI, EO>, coordinate: u32, value: Line<EO>, )
Write the line to the output at the given coordinate using the state.
Sourcefn rank_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn rank_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the rank of the lhs tensor using the state.
Sourcefn rank_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn rank_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the rank of the rhs tensor using the state.
Sourcefn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn rank_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the rank of the out tensor using the state.
Sourcefn len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the length of the lhs tensor using the state.
Sourcefn len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the length of the rhs tensor using the state.
Sourcefn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the length of the out tensor using the state.
Sourcefn buffer_len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn buffer_len_lhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the buffer length of the lhs tensor using the state.
Sourcefn buffer_len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn buffer_len_rhs<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the buffer length of the rhs tensor using the state.
Sourcefn buffer_len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
fn buffer_len_out<EI: Numeric, EO: Numeric>(state: &Self::State<EI, EO>) -> u32
Get the buffer length of the out tensor using the state.
Sourcefn shape_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32
fn shape_lhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, axis: u32, ) -> u32
Get the shape of the lhs tensor along the given axis using the state.
Sourcefn shape_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32
fn shape_rhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, axis: u32, ) -> u32
Get the shape of the rhs tensor along the given axis using the state.
Sourcefn shape_out<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32
fn shape_out<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, axis: u32, ) -> u32
Get the shape of the out tensor along the given axis using the state.
Sourcefn stride_lhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32
fn stride_lhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, axis: u32, ) -> u32
Get the stride of the lhs tensor along the given axis using the state.
Sourcefn stride_rhs<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32
fn stride_rhs<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, axis: u32, ) -> u32
Get the stride of the rhs tensor along the given axis using the state.
Sourcefn stride_out<EI: Numeric, EO: Numeric>(
state: &Self::State<EI, EO>,
axis: u32,
) -> u32
fn stride_out<EI: Numeric, EO: Numeric>( state: &Self::State<EI, EO>, axis: u32, ) -> u32
Get the stride of the out tensor along the given axis using the state.
Sourcefn quantization<MP: MatmulPrecision>(
state: &Self::State<MP::EI, MP::EO>,
) -> Quantization<MP>
fn quantization<MP: MatmulPrecision>( state: &Self::State<MP::EI, MP::EO>, ) -> Quantization<MP>
It is the responsibility of the caller to ensure it is safe to call this function, that is, only when the matmul is indeed quantized. Otherwise, it will most likely result in out-of-bounds memory access.
fn __expand_init_state<EI: Numeric, EO: Numeric>( scope: &mut Scope, input: <Self::Input<EI> as CubeType>::ExpandType, output: <Self::Output<EO> as CubeType>::ExpandType, ) -> <Self::State<EI, EO> as CubeType>::ExpandType
fn __expand_read_lhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, coordinate: <u32 as CubeType>::ExpandType, ) -> <Line<EI> as CubeType>::ExpandType
fn __expand_read_rhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, coordinate: <u32 as CubeType>::ExpandType, ) -> <Line<EI> as CubeType>::ExpandType
fn __expand_read_window_lhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, start: <u32 as CubeType>::ExpandType, end: <u32 as CubeType>::ExpandType, ) -> <Slice<Line<EI>> as CubeType>::ExpandType
fn __expand_read_window_rhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, start: <u32 as CubeType>::ExpandType, end: <u32 as CubeType>::ExpandType, ) -> <Slice<Line<EI>> as CubeType>::ExpandType
fn __expand_as_tensor_map_lhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <TensorMap<EI> as CubeType>::ExpandType
fn __expand_as_tensor_map_rhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <TensorMap<EI> as CubeType>::ExpandType
fn __expand_write_out<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, coordinate: <u32 as CubeType>::ExpandType, value: <Line<EO> as CubeType>::ExpandType, ) -> <() as CubeType>::ExpandType
fn __expand_rank_lhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_rank_rhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_rank_out<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_len_lhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_len_rhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_len_out<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_buffer_len_lhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_buffer_len_rhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_buffer_len_out<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_shape_lhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, axis: <u32 as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_shape_rhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, axis: <u32 as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_shape_out<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, axis: <u32 as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_stride_lhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, axis: <u32 as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_stride_rhs<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, axis: <u32 as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_stride_out<EI: Numeric, EO: Numeric>( scope: &mut Scope, state: <Self::State<EI, EO> as CubeType>::ExpandType, axis: <u32 as CubeType>::ExpandType, ) -> <u32 as CubeType>::ExpandType
fn __expand_quantization<MP: MatmulPrecision>( scope: &mut Scope, state: <Self::State<MP::EI, MP::EO> as CubeType>::ExpandType, ) -> <Quantization<MP> as CubeType>::ExpandType
Dyn Compatibility§
This trait is not dyn compatible.
In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.