//! burn_dragon_vision 0.4.0
//!
//! Foveation and vision sampling utilities for burn dragon.
use std::any::{Any, TypeId};

use burn::tensor::Tensor as BurnTensor;
use burn::tensor::backend::Backend as BackendTrait;
use burn::tensor::{DType, Shape, TensorPrimitive};
use burn_cubecl::fusion::FusionCubeRuntime;
use burn_cubecl::kernel::into_contiguous;
use burn_cubecl::ops::numeric::empty_device;
use burn_cubecl::tensor::CubeTensor;
use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime};
use burn_fusion::FusionTensor;
use burn_fusion::stream::StreamId;
use burn_wgpu::WgpuRuntime;
#[cfg(feature = "cuda")]
use cubecl::cuda::CudaRuntime;
use cubecl::{calculate_cube_count_elemwise, prelude::*};

/// Returns `true` when the backend's float tensor primitive is one of the
/// CubeCL-based representations this module can handle: fused or unfused
/// WGPU tensors always, plus the CUDA equivalents when the `cuda` feature
/// is enabled.
pub(crate) fn supports_backend<B: BackendTrait>() -> bool
where
    B::FloatTensorPrimitive: 'static,
{
    // WGPU primitives (fusion-wrapped with either bool element, or unfused)
    // are supported regardless of feature flags.
    let wgpu_supported = matches_type::<
        B::FloatTensorPrimitive,
        FusionTensor<FusionCubeRuntime<WgpuRuntime, u32>>,
    >() || matches_type::<
        B::FloatTensorPrimitive,
        FusionTensor<FusionCubeRuntime<WgpuRuntime, u8>>,
    >() || matches_type::<B::FloatTensorPrimitive, CubeTensor<WgpuRuntime>>();

    // The CUDA checks mirror the WGPU ones and only exist behind the feature
    // flag; without it, CUDA primitives cannot occur.
    #[cfg(feature = "cuda")]
    let cuda_supported = matches_type::<
        B::FloatTensorPrimitive,
        FusionTensor<FusionCubeRuntime<CudaRuntime, u32>>,
    >() || matches_type::<
        B::FloatTensorPrimitive,
        FusionTensor<FusionCubeRuntime<CudaRuntime, u8>>,
    >() || matches_type::<B::FloatTensorPrimitive, CubeTensor<CudaRuntime>>();
    #[cfg(not(feature = "cuda"))]
    let cuda_supported = false;

    wgpu_supported || cuda_supported
}

/// Computes `output[b, o, :] = Σ_i weights[b, o, i] * tokens[b, i, :]` with a
/// custom CubeCL kernel when the backend supports it.
///
/// Returns `None` — so the caller can fall back to a generic implementation —
/// when the backend is not CubeCL-based, any dimension is zero, or the batch
/// and token dimensions of `weights` and `tokens` disagree.
pub(crate) fn try_weighted_sum_tokens_cubecl<B: BackendTrait>(
    weights: &BurnTensor<B, 3>,
    tokens: &BurnTensor<B, 3>,
) -> Option<BurnTensor<B, 3>>
where
    B::FloatTensorPrimitive: 'static,
{
    if !supports_backend::<B>() {
        return None;
    }

    // weights: [batch, out_tokens, in_tokens]; tokens: [batch, in_tokens, dim].
    let [batch, out_tokens, in_tokens] = weights.shape().dims::<3>();
    let [tok_batch, tok_in, dim] = tokens.shape().dims::<3>();
    let shapes_usable = batch != 0
        && out_tokens != 0
        && in_tokens != 0
        && dim != 0
        && batch == tok_batch
        && in_tokens == tok_in;
    if !shapes_usable {
        return None;
    }

    // Probe fusion-wrapped WGPU primitives first (u32 then u8 bool element),
    // then the CUDA variants, and finally the unfused tensors.
    if let Some(out) = try_weighted_sum_tokens_cubecl_fusion::<B, u32, WgpuRuntime>(weights, tokens)
    {
        return Some(out);
    }
    if let Some(out) = try_weighted_sum_tokens_cubecl_fusion::<B, u8, WgpuRuntime>(weights, tokens)
    {
        return Some(out);
    }
    #[cfg(feature = "cuda")]
    {
        if let Some(out) =
            try_weighted_sum_tokens_cubecl_fusion::<B, u32, CudaRuntime>(weights, tokens)
        {
            return Some(out);
        }
        if let Some(out) =
            try_weighted_sum_tokens_cubecl_fusion::<B, u8, CudaRuntime>(weights, tokens)
        {
            return Some(out);
        }
        if let Some(out) = try_weighted_sum_tokens_cubecl_direct::<B, CudaRuntime>(weights, tokens)
        {
            return Some(out);
        }
    }
    try_weighted_sum_tokens_cubecl_direct::<B, WgpuRuntime>(weights, tokens)
}

/// Fusion path: unwraps `FusionTensor` primitives for runtime `R` with bool
/// element `BT`, resolves them to concrete f32 `CubeTensor`s, runs the
/// kernel, and re-registers the result with the fusion client.
///
/// Returns `None` when the backend primitive is not the expected fusion type
/// or either tensor does not resolve to f32.
fn try_weighted_sum_tokens_cubecl_fusion<B, BT, R>(
    weights: &BurnTensor<B, 3>,
    tokens: &BurnTensor<B, 3>,
) -> Option<BurnTensor<B, 3>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    BT: BoolElement + 'static,
    R: CubeRuntime + 'static,
{
    if !matches_type::<B::FloatTensorPrimitive, FusionTensor<FusionCubeRuntime<R, BT>>>() {
        return None;
    }

    // Unwrap the weights first and keep their client: both operands and the
    // output registration go through the same fusion client.
    let weights_fusion: FusionTensor<FusionCubeRuntime<R, BT>> =
        try_cast_primitive::<B, _>(weights.clone().into_primitive().tensor())?;
    let client = weights_fusion.client.clone();
    let weights_cube = client.resolve_tensor_float::<CubeBackend<R, f32, i32, BT>>(weights_fusion);
    if weights_cube.dtype != DType::F32 {
        return None;
    }

    let tokens_fusion: FusionTensor<FusionCubeRuntime<R, BT>> =
        try_cast_primitive::<B, _>(tokens.clone().into_primitive().tensor())?;
    let tokens_cube = client.resolve_tensor_float::<CubeBackend<R, f32, i32, BT>>(tokens_fusion);
    if tokens_cube.dtype != DType::F32 {
        return None;
    }

    let result = weighted_sum_tokens_cubecl_runtime::<R>(weights_cube, tokens_cube);
    // Hand the raw output back to the fusion graph on the current stream.
    let shape = result.shape.clone();
    let dtype = result.dtype;
    let handle = result.into();
    let registered = client.register_tensor(handle, shape, StreamId::current(), dtype);
    let primitive = try_cast_backend::<B, _>(registered)?;
    Some(BurnTensor::<B, 3>::from_primitive(TensorPrimitive::Float(
        primitive,
    )))
}

/// Unfused path: runs the kernel directly on `CubeTensor<R>` primitives.
///
/// Returns `None` when the backend primitive is not `CubeTensor<R>` or
/// either tensor is not f32 (the kernel is compiled for f32 only).
fn try_weighted_sum_tokens_cubecl_direct<B, R>(
    weights: &BurnTensor<B, 3>,
    tokens: &BurnTensor<B, 3>,
) -> Option<BurnTensor<B, 3>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    R: CubeRuntime + 'static,
{
    if !matches_type::<B::FloatTensorPrimitive, CubeTensor<R>>() {
        return None;
    }

    let weights_cube: CubeTensor<R> =
        try_cast_primitive::<B, _>(weights.clone().into_primitive().tensor())?;
    if weights_cube.dtype != DType::F32 {
        return None;
    }
    let tokens_cube: CubeTensor<R> =
        try_cast_primitive::<B, _>(tokens.clone().into_primitive().tensor())?;
    if tokens_cube.dtype != DType::F32 {
        return None;
    }

    let result = weighted_sum_tokens_cubecl_runtime::<R>(weights_cube, tokens_cube);
    let primitive = try_cast_backend::<B, _>(result)?;
    Some(BurnTensor::<B, 3>::from_primitive(TensorPrimitive::Float(
        primitive,
    )))
}

/// Launches the weighted-sum kernel on a CubeCL runtime.
///
/// `weights` is `[batch, out_tokens, in_tokens]` and `tokens` is
/// `[batch, in_tokens, dim]` (shape agreement and f32 dtype are validated by
/// the callers); returns a freshly allocated `[batch, out_tokens, dim]` f32
/// tensor on `weights`' device.
fn weighted_sum_tokens_cubecl_runtime<R: CubeRuntime>(
    weights: CubeTensor<R>,
    tokens: CubeTensor<R>,
) -> CubeTensor<R> {
    // Force dense layouts so the stride arithmetic in the kernel stays cheap
    // and predictable.
    let weights = into_contiguous(weights);
    let tokens = into_contiguous(tokens);
    let [batch, out_tokens, _] = weights.shape.dims::<3>();
    let dim = tokens.shape.dims::<3>()[2];

    let client = weights.client.clone();
    let device = weights.device.clone();
    let output =
        empty_device::<R, f32>(client.clone(), device, Shape::new([batch, out_tokens, dim]));
    // One invocation per output element, 256 invocations per cube; the kernel
    // itself guards positions past `out_elems`.
    let out_elems = output.shape.num_elements();
    let cube_dim = CubeDim::new(256, 1, 1);
    let cube_count = calculate_cube_count_elemwise(out_elems, cube_dim);

    // Vectorization factor 1: the kernel indexes scalar f32 elements.
    weighted_sum_kernel::launch::<R>(
        &client,
        cube_count,
        cube_dim,
        weights.as_tensor_arg::<f32>(1),
        tokens.as_tensor_arg::<f32>(1),
        output.as_tensor_arg::<f32>(1),
    );
    output
}

/// One invocation per output element: computes
/// `output[b, o, d] = Σ_i weights[b, o, i] * tokens[b, i, d]`.
#[cube(launch)]
fn weighted_sum_kernel(weights: &Tensor<f32>, tokens: &Tensor<f32>, output: &mut Tensor<f32>) {
    // Guard the tail cubes whose positions extend past the element count.
    if ABSOLUTE_POS >= output.len() {
        terminate!();
    }
    let dim = output.shape(2);
    let out_tokens = output.shape(1);
    // Zero-sized axes would make the index decomposition below divide by zero.
    if dim == 0 || out_tokens == 0 {
        terminate!();
    }
    // Decompose the flat position into (batch b, output token o, channel d),
    // with d fastest-varying to match the output's last axis.
    let d = ABSOLUTE_POS % dim;
    let pos = ABSOLUTE_POS / dim;
    let o = pos % out_tokens;
    let b = pos / out_tokens;
    let in_tokens = weights.shape(2);

    // Stride-based base offsets for the (b, o) weight row and the (b, :, d)
    // token column; the reduction index i is added per iteration.
    let weight_base = b * weights.stride(0) + o * weights.stride(1);
    let token_base = b * tokens.stride(0) + d * tokens.stride(2);

    // Dot product over the shared in_tokens axis.
    let mut acc = 0.0f32;
    let mut i = 0u32;
    while i < in_tokens {
        let w_idx = weight_base + i * weights.stride(2);
        let t_idx = token_base + i * tokens.stride(1);
        acc += weights[w_idx] * tokens[t_idx];
        i += 1u32;
    }

    let out_idx = b * output.stride(0) + o * output.stride(1) + d * output.stride(2);
    output[out_idx] = acc;
}

/// Runtime type-equality probe: `true` iff `A` and `B` are the same concrete
/// type, compared by `TypeId`.
fn matches_type<A: 'static, B: 'static>() -> bool {
    let (left, right) = (TypeId::of::<A>(), TypeId::of::<B>());
    left == right
}

/// Attempts to reinterpret a backend float primitive as the concrete tensor
/// type `T`, returning `None` when the runtime types differ.
///
/// Downcasts in place through `&mut dyn Any` instead of boxing, so a failed
/// probe (the common case while dispatching across runtimes) performs no
/// heap allocation.
fn try_cast_primitive<B: BackendTrait, T: 'static>(value: B::FloatTensorPrimitive) -> Option<T>
where
    B::FloatTensorPrimitive: 'static,
{
    // `Option<A>` and `Option<T>` share a `TypeId` exactly when `A == T`, so
    // a successful downcast lets us move the value out without boxing it.
    let mut slot = Some(value);
    let any: &mut dyn Any = &mut slot;
    any.downcast_mut::<Option<T>>().and_then(Option::take)
}

/// Attempts to reinterpret a concrete tensor `T` as the backend's float
/// primitive, returning `None` when the runtime types differ.
///
/// Mirrors `try_cast_primitive`: downcasts in place through `&mut dyn Any`
/// rather than boxing, avoiding a heap allocation per cast attempt.
fn try_cast_backend<B: BackendTrait, T: 'static>(value: T) -> Option<B::FloatTensorPrimitive>
where
    B::FloatTensorPrimitive: 'static,
{
    // `Option<T>` and `Option<Primitive>` share a `TypeId` exactly when the
    // inner types match, so a successful downcast lets us take the value out.
    let mut slot = Some(value);
    let any: &mut dyn Any = &mut slot;
    any.downcast_mut::<Option<B::FloatTensorPrimitive>>()
        .and_then(Option::take)
}