burn_dragon_vision 0.4.0

Foveation and vision sampling utilities for burn dragon
Documentation
use std::any::{Any, TypeId};

use burn::tensor::Tensor as BurnTensor;
use burn::tensor::backend::Backend as BackendTrait;
use burn::tensor::{DType, Shape, TensorData, TensorPrimitive};
use burn_cubecl::fusion::FusionCubeRuntime;
use burn_cubecl::kernel::into_contiguous;
use burn_cubecl::ops::numeric::empty_device;
use burn_cubecl::tensor::CubeTensor;
use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime};
use burn_fusion::FusionTensor;
use burn_fusion::stream::StreamId;
use burn_wgpu::{KernelSource, SourceKernel, SourceTemplate, WgpuRuntime};
use cubecl::prelude::*;
use cubecl_runtime::server::Bindings;

use crate::SCATTER_BUFFER_SHADER;

const META_LEN: usize = 4;
const WORKGROUP_SIZE: u32 = 8;

/// Reports whether `B`'s float tensor primitive is one this module can hand to the WGSL kernel.
///
/// Three primitive types are recognised:
/// - `FusionTensor<FusionCubeRuntime<WgpuRuntime, u32>>` (fusion, `u32` bool element),
/// - `FusionTensor<FusionCubeRuntime<WgpuRuntime, u8>>` (fusion, `u8` bool element),
/// - `CubeTensor<WgpuRuntime>` (plain cube backend, no fusion).
pub(crate) fn supports_backend<B: BackendTrait>() -> bool
where
    B::FloatTensorPrimitive: 'static,
{
    let fused_with_u32_bools =
        matches_type::<B::FloatTensorPrimitive, FusionTensor<FusionCubeRuntime<WgpuRuntime, u32>>>();
    let fused_with_u8_bools =
        matches_type::<B::FloatTensorPrimitive, FusionTensor<FusionCubeRuntime<WgpuRuntime, u8>>>();
    let plain_cube = matches_type::<B::FloatTensorPrimitive, CubeTensor<WgpuRuntime>>();

    fused_with_u32_bools || fused_with_u8_bools || plain_cube
}

/// Runs the custom WGSL scatter kernel on `weights` and `tokens` when the backend allows it.
///
/// `weights` is destructured as `[batch, out_tokens, in_tokens]` and `tokens` as
/// `[batch, in_tokens, dim]`. On success the result has shape `[batch, out_tokens, dim]`.
/// NOTE(review): the name suggests
/// `out[b, o, d] = sum_i weights[b, o, i] * tokens[b, i, d]`. The shader source is not
/// visible here, so confirm against `SCATTER_BUFFER_SHADER`.
///
/// Returns `None` in any of these cases:
/// - the backend's float primitive is not supported (see `supports_backend`);
/// - any dimension is zero;
/// - the batch or `in_tokens` dimensions of the two inputs disagree;
/// - any dimension is too large to be encoded exactly in the kernel's `f32` metadata;
/// - the selected path rejects the inputs (e.g. non-`F32` dtype).
///
/// The fusion paths (`u32`, then `u8` bool elements) are tried before the direct
/// `CubeTensor` path. The first `Some` is returned.
pub(crate) fn try_weighted_sum_tokens_wgsl<B: BackendTrait>(
    weights: &BurnTensor<B, 3>,
    tokens: &BurnTensor<B, 3>,
) -> Option<BurnTensor<B, 3>>
where
    B::FloatTensorPrimitive: 'static,
{
    // `build_meta` passes the dimensions to the kernel as `f32`. An `f32` represents
    // integers exactly only up to 2^MANTISSA_DIGITS (= 2^24), so larger values would reach
    // the shader silently rounded. Decline them the same way as other unsupported inputs.
    // This bound also keeps the later `as u32` dispatch-size casts lossless.
    const MAX_EXACT_F32_INT: usize = 1 << f32::MANTISSA_DIGITS;

    if !supports_backend::<B>() {
        return None;
    }
    let [batch, out_tokens, in_tokens] = weights.shape().dims::<3>();
    let [tok_batch, tok_in, dim] = tokens.shape().dims::<3>();
    let dims = [batch, out_tokens, in_tokens, dim];
    if dims.contains(&0) {
        return None;
    }
    if batch != tok_batch || in_tokens != tok_in {
        return None;
    }
    if dims.iter().any(|&extent| extent > MAX_EXACT_F32_INT) {
        return None;
    }

    try_weighted_sum_tokens_wgsl_fusion::<B, u32>(weights, tokens)
        .or_else(|| try_weighted_sum_tokens_wgsl_fusion::<B, u8>(weights, tokens))
        .or_else(|| try_weighted_sum_tokens_wgsl_direct::<B>(weights, tokens))
}

/// Fusion-backend path for the WGSL kernel.
///
/// It unwraps both inputs to `CubeTensor<WgpuRuntime>` via the fusion client, launches the
/// kernel, and registers the output with the same client on `StreamId::current()`.
///
/// Returns `None` in these cases:
/// - `B`'s float primitive is not `FusionTensor<FusionCubeRuntime<WgpuRuntime, BT>>`;
/// - either resolved input is not `DType::F32`;
/// - the metadata tensor cannot be resolved as `F32`;
/// - the registered output cannot be cast back to `B`'s primitive.
fn try_weighted_sum_tokens_wgsl_fusion<B, BT>(
    weights: &BurnTensor<B, 3>,
    tokens: &BurnTensor<B, 3>,
) -> Option<BurnTensor<B, 3>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    BT: BoolElement + 'static,
{
    let is_fused_wgpu =
        matches_type::<B::FloatTensorPrimitive, FusionTensor<FusionCubeRuntime<WgpuRuntime, BT>>>();
    if !is_fused_wgpu {
        return None;
    }
    let meta_device = weights.device();

    // The client that owns `weights` is reused for `tokens` and for registering the output.
    let weights_fused: FusionTensor<FusionCubeRuntime<WgpuRuntime, BT>> =
        try_cast_primitive::<B, _>(weights.clone().into_primitive().tensor())?;
    let client = weights_fused.client.clone();
    let weights_cube =
        client.resolve_tensor_float::<CubeBackend<WgpuRuntime, f32, i32, BT>>(weights_fused);
    if weights_cube.dtype != DType::F32 {
        return None;
    }

    let tokens_fused: FusionTensor<FusionCubeRuntime<WgpuRuntime, BT>> =
        try_cast_primitive::<B, _>(tokens.clone().into_primitive().tensor())?;
    let tokens_cube =
        client.resolve_tensor_float::<CubeBackend<WgpuRuntime, f32, i32, BT>>(tokens_fused);
    if tokens_cube.dtype != DType::F32 {
        return None;
    }

    let meta_tensor = build_meta::<B>(
        weights_cube.shape.dims::<3>(),
        tokens_cube.shape.dims::<3>(),
        meta_device,
    );
    let meta_cube = resolve_fusion_tensor::<B, BT>(&meta_tensor)?;

    let result =
        weighted_sum_tokens_wgsl_runtime::<WgpuRuntime>(weights_cube, tokens_cube, meta_cube);
    let (out_shape, out_dtype) = (result.shape.clone(), result.dtype);
    let handle = result.into();
    let registered = client.register_tensor(handle, out_shape, StreamId::current(), out_dtype);

    try_cast_backend::<B, _>(registered)
        .map(|prim| BurnTensor::<B, 3>::from_primitive(TensorPrimitive::Float(prim)))
}

/// Non-fusion path for backends whose float primitive is exactly `CubeTensor<WgpuRuntime>`.
///
/// Returns `None` in these cases:
/// - the primitive type does not match;
/// - either input is not `DType::F32`;
/// - the metadata tensor cannot be resolved as `F32`;
/// - the output cannot be cast back to `B`'s primitive.
fn try_weighted_sum_tokens_wgsl_direct<B: BackendTrait>(
    weights: &BurnTensor<B, 3>,
    tokens: &BurnTensor<B, 3>,
) -> Option<BurnTensor<B, 3>>
where
    B::FloatTensorPrimitive: 'static,
{
    let is_cube_wgpu = matches_type::<B::FloatTensorPrimitive, CubeTensor<WgpuRuntime>>();
    if !is_cube_wgpu {
        return None;
    }
    let meta_device = weights.device();

    let weights_cube = try_cast_primitive::<B, CubeTensor<WgpuRuntime>>(
        weights.clone().into_primitive().tensor(),
    )
    .filter(|cube| cube.dtype == DType::F32)?;
    let tokens_cube = try_cast_primitive::<B, CubeTensor<WgpuRuntime>>(
        tokens.clone().into_primitive().tensor(),
    )
    .filter(|cube| cube.dtype == DType::F32)?;

    let meta_tensor = build_meta::<B>(
        weights_cube.shape.dims::<3>(),
        tokens_cube.shape.dims::<3>(),
        meta_device,
    );
    let meta_cube = resolve_direct_tensor::<B>(&meta_tensor)?;

    let result =
        weighted_sum_tokens_wgsl_runtime::<WgpuRuntime>(weights_cube, tokens_cube, meta_cube);
    try_cast_backend::<B, _>(result)
        .map(|prim| BurnTensor::<B, 3>::from_primitive(TensorPrimitive::Float(prim)))
}

/// Packs the kernel's four scalar dimensions into a rank-1 tensor of length `META_LEN`.
///
/// Layout: `[weights_shape[0], weights_shape[1], weights_shape[2], tokens_shape[2]]`, i.e.
/// `[batch, out_tokens, in_tokens, dim]` as the caller destructures them. Each value is
/// stored as `f32`, so values above 2^24 are not represented exactly.
fn build_meta<B: BackendTrait>(
    weights_shape: [usize; 3],
    tokens_shape: [usize; 3],
    device: B::Device,
) -> BurnTensor<B, 1> {
    let [batch, out_tokens, in_tokens] = weights_shape;
    let dim = tokens_shape[2];
    let values: Vec<f32> = [batch, out_tokens, in_tokens, dim]
        .iter()
        .map(|&extent| extent as f32)
        .collect();
    let data = TensorData::new(values, [META_LEN]);
    BurnTensor::<B, 1>::from_data(data, &device)
}

/// Launches `ScatterBufferKernel` and returns a newly allocated `f32` output tensor of shape
/// `[batch, out_tokens, dim]`.
///
/// All three inputs are made contiguous first. The dispatch grid is
/// `ceil(out_tokens / WORKGROUP_SIZE) x ceil(dim / WORKGROUP_SIZE) x batch`, using
/// `WORKGROUP_SIZE x WORKGROUP_SIZE x 1` workgroups. Buffers are bound in the order
/// weights, tokens, output, meta, which must match the shader's binding declarations.
fn weighted_sum_tokens_wgsl_runtime<R: CubeRuntime>(
    weights: CubeTensor<R>,
    tokens: CubeTensor<R>,
    meta: CubeTensor<R>,
) -> CubeTensor<R> {
    let (weights, tokens, meta) = (
        into_contiguous(weights),
        into_contiguous(tokens),
        into_contiguous(meta),
    );

    let [batch, out_tokens, _in_tokens] = weights.shape.dims::<3>();
    let [_, _, dim] = tokens.shape.dims::<3>();
    let client = weights.client.clone();
    let output = empty_device::<R, f32>(
        client.clone(),
        weights.device.clone(),
        Shape::new([batch, out_tokens, dim]),
    );

    let count = CubeCount::Static(
        div_ceil_u32(out_tokens as u32, WORKGROUP_SIZE),
        div_ceil_u32(dim as u32, WORKGROUP_SIZE),
        batch as u32,
    );
    let kernel = SourceKernel::new(
        ScatterBufferKernel,
        CubeDim::new(WORKGROUP_SIZE, WORKGROUP_SIZE, 1),
    );
    let buffers: Vec<_> = [&weights, &tokens, &output, &meta]
        .into_iter()
        .map(|tensor| tensor.handle.clone().binding())
        .collect();

    client.execute(Box::new(kernel), count, Bindings::new().with_buffers(buffers));
    output
}

/// Resolves a rank-1 fusion-backed tensor to its underlying `CubeTensor<WgpuRuntime>`.
///
/// Returns `None` in either of these cases:
/// - `B`'s float primitive is not `FusionTensor<FusionCubeRuntime<WgpuRuntime, BT>>`;
/// - the resolved tensor's dtype is not `DType::F32`.
fn resolve_fusion_tensor<B, BT>(
    tensor: &BurnTensor<B, 1>,
) -> Option<CubeTensor<WgpuRuntime>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    BT: BoolElement + 'static,
{
    let fused: FusionTensor<FusionCubeRuntime<WgpuRuntime, BT>> =
        try_cast_primitive::<B, _>(tensor.clone().into_primitive().tensor())?;
    let owner = fused.client.clone();
    let resolved = owner.resolve_tensor_float::<CubeBackend<WgpuRuntime, f32, i32, BT>>(fused);
    Some(resolved).filter(|cube| cube.dtype == DType::F32)
}

/// Reinterprets a rank-1 tensor's primitive as `CubeTensor<WgpuRuntime>`.
///
/// Returns `None` if the primitive is some other type or if its dtype is not `DType::F32`.
fn resolve_direct_tensor<B: BackendTrait>(
    tensor: &BurnTensor<B, 1>,
) -> Option<CubeTensor<WgpuRuntime>>
where
    B::FloatTensorPrimitive: 'static,
{
    try_cast_primitive::<B, CubeTensor<WgpuRuntime>>(tensor.clone().into_primitive().tensor())
        .filter(|cube| cube.dtype == DType::F32)
}

/// Integer division of `value` by `divisor`, rounded up.
///
/// Used to size the dispatch grid. Panics if `divisor` is zero, like `u32::div_ceil`.
fn div_ceil_u32(value: u32, divisor: u32) -> u32 {
    u32::div_ceil(value, divisor)
}

/// Kernel descriptor whose WGSL source is the crate-level `SCATTER_BUFFER_SHADER` text.
#[derive(Clone)]
struct ScatterBufferKernel;

impl KernelSource for ScatterBufferKernel {
    /// Wraps the shader text in a `SourceTemplate`.
    fn source(&self) -> SourceTemplate {
        let shader_text = SCATTER_BUFFER_SHADER;
        SourceTemplate::new(shader_text)
    }

    /// The kernel identity is derived from this type alone.
    fn id(&self) -> KernelId {
        KernelId::new::<ScatterBufferKernel>()
    }
}

/// Returns `true` when `Lhs` and `Rhs` are the same concrete type, compared via `TypeId`.
fn matches_type<Lhs: 'static, Rhs: 'static>() -> bool {
    let (lhs_id, rhs_id) = (TypeId::of::<Lhs>(), TypeId::of::<Rhs>());
    lhs_id.eq(&rhs_id)
}

/// Attempts to view `B`'s float primitive as the concrete type `T` via `Any` downcasting.
///
/// Succeeds only when `T` is exactly `B::FloatTensorPrimitive`. Otherwise the value is
/// dropped and `None` is returned.
fn try_cast_primitive<B: BackendTrait, T: 'static>(value: B::FloatTensorPrimitive) -> Option<T>
where
    B::FloatTensorPrimitive: 'static,
{
    match (Box::new(value) as Box<dyn Any>).downcast::<T>() {
        Ok(concrete) => Some(*concrete),
        Err(_) => None,
    }
}

/// Attempts to view a concrete tensor value as `B`'s float primitive via `Any` downcasting.
///
/// Succeeds only when `T` is exactly `B::FloatTensorPrimitive`. Otherwise the value is
/// dropped and `None` is returned.
fn try_cast_backend<B: BackendTrait, T: 'static>(value: T) -> Option<B::FloatTensorPrimitive>
where
    B::FloatTensorPrimitive: 'static,
{
    match (Box::new(value) as Box<dyn Any>).downcast::<B::FloatTensorPrimitive>() {
        Ok(primitive) => Some(*primitive),
        Err(_) => None,
    }
}