use crate::BoolElement;
use crate::{CubeBackend, CubeRuntime, FloatElement, IntElement, kernel, tensor::CubeTensor};
use burn_backend::tensor::{BoolTensor, FloatTensor, IntTensor, QuantizedTensor};
use burn_backend::{DType, Shape};
use burn_cubecl_fusion::optim::reduce::ReduceSettings;
use burn_cubecl_fusion::optim::reduce_broadcasted::ReduceBroadcastedFuser;
use burn_cubecl_fusion::{
CubeFusionHandle, FallbackOperation,
optim::{
CubeOptimization, CubeOptimizationState,
elemwise::{ElementWiseFuser, ElemwiseOptimization},
matmul::{MatmulFuser, MatmulOptimization},
reduce::{ReduceFuser, ReduceOptimization},
reduce_broadcasted::ReduceBroadcastedOptimization,
},
};
use burn_fusion::UnfusedOp;
use burn_fusion::{
FusionBackend, FusionRuntime,
stream::{Operation, OrderedExecution},
};
use burn_ir::{BackendIr, TensorHandle};
use burn_std::Metadata;
use core::marker::PhantomData;
use std::sync::Arc;
impl<R> burn_fusion::Optimization<FusionCubeRuntime<R>> for CubeOptimization<R>
where
R: CubeRuntime,
{
fn execute(
&mut self,
context: &mut burn_fusion::stream::Context<
<FusionCubeRuntime<R> as FusionRuntime>::FusionHandle,
>,
execution: &OrderedExecution<FusionCubeRuntime<R>>,
) {
match self {
Self::ElementWise(op) => op.execute(context),
Self::Matmul(op) => op.execute(context, |index| {
let operation = execution.operation_within_optimization(index);
Box::new(FallbackOperationWrapper::new(operation))
}),
Self::Reduce(op) => op.execute(context, |index| {
let operation = execution.operation_within_optimization(index);
Box::new(FallbackOperationWrapper::new(operation))
}),
Self::ReduceBroadcasted(op) => op.execute(context, |index| {
let operation = execution.operation_within_optimization(index);
Box::new(FallbackOperationWrapper::new(operation))
}),
}
}
fn to_state(&self) -> CubeOptimizationState {
self.to_opt_state()
}
fn from_state(device: &R::Device, state: CubeOptimizationState) -> Self {
match state {
CubeOptimizationState::ElementWise(state) => {
Self::ElementWise(ElemwiseOptimization::from_state(device, state))
}
CubeOptimizationState::Matmul(state) => {
Self::Matmul(MatmulOptimization::from_state(device, state))
}
CubeOptimizationState::Reduce(state) => {
Self::Reduce(ReduceOptimization::from_state(device, state))
}
CubeOptimizationState::ReduceBroadcasted(state) => {
Self::ReduceBroadcasted(ReduceBroadcastedOptimization::from_state(device, state))
}
}
}
}
struct FallbackOperationWrapper<O: Clone> {
operation: O,
}
impl<O: Clone> FallbackOperationWrapper<O> {
fn new(op: O) -> Self {
Self { operation: op }
}
}
impl<R: CubeRuntime> FallbackOperation<R>
for FallbackOperationWrapper<Arc<dyn Operation<FusionCubeRuntime<R>>>>
{
fn run(&self, context: &mut burn_fusion::stream::Context<CubeFusionHandle<R>>) {
self.operation.as_ref().execute(&mut context.handles);
}
}
impl<R: CubeRuntime> FallbackOperation<R>
for FallbackOperationWrapper<UnfusedOp<FusionCubeRuntime<R>>>
{
fn run(&self, context: &mut burn_fusion::stream::Context<CubeFusionHandle<R>>) {
self.operation.execute(&mut context.handles);
}
}
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> BackendIr
for CubeBackend<R, F, I, BT>
{
type Handle = CubeFusionHandle<R>;
fn float_tensor(handle: TensorHandle<Self::Handle>) -> FloatTensor<Self> {
into_tensor(handle.handle, handle.shape)
}
fn int_tensor(handle: TensorHandle<Self::Handle>) -> IntTensor<Self> {
into_tensor(handle.handle, handle.shape)
}
fn bool_tensor(handle: TensorHandle<Self::Handle>) -> BoolTensor<Self> {
into_tensor(handle.handle, handle.shape)
}
fn quantized_tensor(handle: TensorHandle<Self::Handle>) -> QuantizedTensor<Self> {
into_tensor(handle.handle, handle.shape)
}
fn float_tensor_handle(tensor: FloatTensor<Self>) -> Self::Handle {
tensor.into()
}
fn int_tensor_handle(tensor: IntTensor<Self>) -> Self::Handle {
tensor.into()
}
fn bool_tensor_handle(tensor: BoolTensor<Self>) -> Self::Handle {
tensor.into()
}
fn quantized_tensor_handle(tensor: QuantizedTensor<Self>) -> Self::Handle {
tensor.into()
}
}
impl<R: CubeRuntime> FusionRuntime for FusionCubeRuntime<R> {
type OptimizationState = CubeOptimizationState;
type Optimization = CubeOptimization<R>;
type FusionHandle = CubeFusionHandle<R>;
type FusionDevice = R::CubeDevice;
fn fusers(device: R::Device) -> Vec<Box<dyn burn_fusion::OperationFuser<Self::Optimization>>> {
vec![
Box::new(ElementWiseFuser::new(device.clone())),
Box::new(MatmulFuser::new(device.clone())),
Box::new(ReduceFuser::new(device.clone(), ReduceSettings::Always)),
Box::new(ReduceBroadcastedFuser::new(device.clone())),
]
}
}
#[derive(Debug)]
pub struct FusionCubeRuntime<R: CubeRuntime> {
_b: PhantomData<R>,
}
impl<R: CubeRuntime, F: FloatElement, I: IntElement, BT: BoolElement> FusionBackend
for CubeBackend<R, F, I, BT>
{
type FusionRuntime = FusionCubeRuntime<R>;
type FullPrecisionBackend = CubeBackend<R, f32, i32, BT>;
fn cast_float(tensor: FloatTensor<Self>, dtype: DType) -> Self::Handle {
kernel::cast(tensor, dtype).into()
}
}
fn into_tensor<R: CubeRuntime>(handle: CubeFusionHandle<R>, shape: Shape) -> CubeTensor<R> {
CubeTensor {
client: handle.client.clone(),
handle: handle.handle.clone(),
device: handle.device.clone(),
meta: Box::new(Metadata::new(shape, handle.strides.clone())),
dtype: handle.dtype,
qparams: handle.qparams.clone(),
}
}
impl<R: CubeRuntime> From<CubeTensor<R>> for CubeFusionHandle<R> {
fn from(value: CubeTensor<R>) -> Self {
Self {
client: value.client.clone(),
handle: value.handle.clone(),
device: value.device.clone(),
strides: value.meta.strides.clone(),
dtype: value.dtype,
qparams: value.qparams.clone(),
}
}
}