use crate::launch::{
ConcreteInputsFactory, ConcreteOutputFactory, InputArg, InputRuntimeArg, MatmulArgs, OutputArg,
OutputRuntimeArg,
};
use crate::{
definition::MatmulProblem, definition::MatmulSetupError, definition::MatmulVectorSizes,
definition::cube_mapping_launch,
};
use crate::{
routines::LaunchInfo,
routines::{BlueprintStrategy, Routine},
{definition::MatmulElems, launch::ConfigRuntimeArg},
};
use cubecl::{
prelude::TensorBinding,
{Runtime, client::ComputeClient},
};
use cubek_std::InputBinding;
/// Launches matmul routine `A` on concrete (tensor-backed) `lhs`/`rhs` inputs and `out`.
///
/// For each operand that is `InputBinding::Quantized`, its entry in `vector_sizes` is
/// multiplied by the quantization scheme's `num_quants()` before the device settings are
/// queried. The resulting "view" vector sizes drive device settings, blueprint expansion
/// and launch preparation. The concrete input/output arguments are then built from the
/// prepared blueprint and `launch_info.vector_sizes`, and handed to `launch_kernel`.
///
/// `dtypes` is passed mutably to both factories (input first, then output); whether they
/// modify it is up to the factory implementations.
/// NOTE(review): the kernel launch itself uses `launch_info.dtypes`, not this `dtypes`.
///
/// # Errors
///
/// Propagates failures from blueprint expansion, launch preparation and the launch itself.
// Eight parameters exceed clippy's default `too_many_arguments` threshold; the
// `result_large_err` allowance presumably reflects the size of `MatmulSetupError`.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel_concrete<MA: MatmulArgs<Config = ()>, R: Runtime, A: Routine<()>>(
    client: &ComputeClient<R>,
    lhs: InputBinding<R>,
    rhs: InputBinding<R>,
    out: TensorBinding<R>,
    problem: MatmulProblem,
    vector_sizes: MatmulVectorSizes,
    blueprint_strategy: &BlueprintStrategy<(), A>,
    dtypes: &mut MatmulElems,
) -> Result<(), MatmulSetupError>
where
    InputArg<MA>: ConcreteInputsFactory<A>,
    OutputArg<MA>: ConcreteOutputFactory<A>,
{
    // Quantized operands are widened by `num_quants()` (presumably the number of quantized
    // values packed per stored element); non-quantized operands keep their given size.
    let view_vector_sizes = {
        let mut sizes = vector_sizes;
        if let InputBinding::Quantized { scheme, .. } = &lhs {
            sizes.lhs *= scheme.num_quants();
        }
        if let InputBinding::Quantized { scheme, .. } = &rhs {
            sizes.rhs *= scheme.num_quants();
        }
        sizes
    };

    let device_settings = A::device_settings(client, view_vector_sizes);
    let expand_info = A::expand_blueprint(&problem, &device_settings, blueprint_strategy)?;
    let launch_info = A::prepare(&problem, &device_settings, expand_info)?;

    let blueprint = &launch_info.blueprint;
    let kernel_vector_sizes = &launch_info.vector_sizes;

    let input = <InputArg<MA> as ConcreteInputsFactory<A>>::create(
        lhs,
        rhs,
        blueprint,
        &problem,
        kernel_vector_sizes,
        dtypes,
    );
    let output = <OutputArg<MA> as ConcreteOutputFactory<A>>::create(
        out,
        blueprint,
        &problem,
        kernel_vector_sizes,
        dtypes,
    );

    // Concrete arguments carry no runtime config, hence the unit `()`.
    launch_kernel::<MA, R, A>(client, input, output, (), launch_info)
}
/// Launches matmul routine `A` with already-constructed runtime arguments.
///
/// Device settings are derived from `view_vector_sizes`. The blueprint is then expanded
/// according to `blueprint_strategy`, the launch is prepared, and the resulting launch
/// info is passed to [`launch_kernel`] together with `input`, `output` and `config`.
///
/// Unlike `launch_kernel_concrete`, no quantization scaling is applied here.
/// NOTE(review): callers presumably pass vector sizes already adjusted for the view.
///
/// # Errors
///
/// Returns an error if blueprint expansion, launch preparation, or the launch itself fails.
pub fn launch_kernel_virtual<MA: MatmulArgs, R: Runtime, A: Routine<MA::Config>>(
    client: &ComputeClient<R>,
    input: InputRuntimeArg<MA, R>,
    output: OutputRuntimeArg<MA, R>,
    config: ConfigRuntimeArg<MA, R>,
    problem: MatmulProblem,
    view_vector_sizes: MatmulVectorSizes,
    blueprint_strategy: &BlueprintStrategy<MA::Config, A>,
) -> Result<(), MatmulSetupError> {
    let device_settings = A::device_settings(client, view_vector_sizes);
    let expand_info = A::expand_blueprint(&problem, &device_settings, blueprint_strategy)?;
    let launch_info = A::prepare(&problem, &device_settings, expand_info)?;
    launch_kernel::<MA, R, A>(client, input, output, config, launch_info)
}
/// Dispatches matmul routine `A` using a prepared [`LaunchInfo`].
///
/// The cube dimension, address type, blueprint, dtypes and vector sizes are taken
/// directly from `launch_info`. Its cube-count plan is used twice: once resolved into
/// the cube count, and once converted by `cube_mapping_launch` into the cube-mapping
/// kernel argument.
///
/// # Errors
///
/// Returns whatever error `A::launch` reports.
pub fn launch_kernel<MA: MatmulArgs, R: Runtime, A: Routine<MA::Config>>(
    client: &ComputeClient<R>,
    input: InputRuntimeArg<MA, R>,
    output: OutputRuntimeArg<MA, R>,
    config: ConfigRuntimeArg<MA, R>,
    launch_info: LaunchInfo<A::Blueprint>,
) -> Result<(), MatmulSetupError> {
    A::launch::<MA, R>(
        client,
        launch_info.cube_dim,
        launch_info.cube_count_plan.resolve(),
        launch_info.address_type,
        input,
        output,
        config,
        // Built from the same plan that produced the resolved cube count above.
        cube_mapping_launch(&launch_info.cube_count_plan),
        launch_info.blueprint,
        &launch_info.dtypes,
        &launch_info.vector_sizes,
    )
}