burn_dragon_train 0.4.0

Training utilities for burn_dragon
Documentation
use std::any::{Any, TypeId};

use burn::tensor::Tensor as BurnTensor;
use burn::tensor::backend::Backend as BackendTrait;
use burn::tensor::{DType, Shape, TensorPrimitive};
use burn_cubecl::fusion::FusionCubeRuntime;
use burn_cubecl::kernel::into_contiguous;
use burn_cubecl::ops::numeric::empty_device;
use burn_cubecl::tensor::CubeTensor;
use burn_cubecl::{BoolElement, CubeBackend, CubeRuntime};
use burn_fusion::FusionTensor;
use burn_fusion::stream::StreamId;
use burn_wgpu::WgpuRuntime;
#[cfg(feature = "cuda")]
use cubecl::cuda::CudaRuntime;
use cubecl::{calculate_cube_count_elemwise, prelude::*};

/// Largest row width (size of the second tensor dimension) handled by the
/// percentile kernel in this file.
///
/// `try_percentile_thresholds_cubecl` returns `None` for wider inputs, and the
/// kernel sizes its per-row scratch buffer with this value.
pub const MAX_GROUP: u32 = 8;

pub fn supports_backend<B: BackendTrait>() -> bool
where
    B::FloatTensorPrimitive: 'static,
{
    #[cfg(feature = "cuda")]
    {
        matches_type::<
            B::FloatTensorPrimitive,
            FusionTensor<FusionCubeRuntime<WgpuRuntime, u32>>,
        >() || matches_type::<
            B::FloatTensorPrimitive,
            FusionTensor<FusionCubeRuntime<WgpuRuntime, u8>>,
        >() || matches_type::<B::FloatTensorPrimitive, CubeTensor<WgpuRuntime>>()
            || matches_type::<
                B::FloatTensorPrimitive,
                FusionTensor<FusionCubeRuntime<CudaRuntime, u32>>,
            >()
            || matches_type::<
                B::FloatTensorPrimitive,
                FusionTensor<FusionCubeRuntime<CudaRuntime, u8>>,
            >()
            || matches_type::<B::FloatTensorPrimitive, CubeTensor<CudaRuntime>>()
    }
    #[cfg(not(feature = "cuda"))]
    {
        matches_type::<
            B::FloatTensorPrimitive,
            FusionTensor<FusionCubeRuntime<WgpuRuntime, u32>>,
        >() || matches_type::<
            B::FloatTensorPrimitive,
            FusionTensor<FusionCubeRuntime<WgpuRuntime, u8>>,
        >() || matches_type::<B::FloatTensorPrimitive, CubeTensor<WgpuRuntime>>()
    }
}

/// Attempts to compute one quantile threshold per row of `values` with the
/// CubeCL kernel defined in this file.
///
/// Returns `None` when backend `B` is not supported (see [`supports_backend`]),
/// when a row is wider than [`MAX_GROUP`], or when none of the
/// runtime-specific attempts accepts the tensor. If either dimension is zero,
/// a zero-filled tensor of shape `[max(batch, 1), 1]` is returned without
/// launching anything.
///
/// `quantile` is sanitised before use: NaN becomes `0.0` and every other value
/// is clamped into `[0.0, 1.0]`.
pub fn try_percentile_thresholds_cubecl<B: BackendTrait>(
    values: &BurnTensor<B, 2>,
    quantile: f32,
) -> Option<BurnTensor<B, 2>>
where
    B::FloatTensorPrimitive: 'static,
{
    if !supports_backend::<B>() {
        return None;
    }

    let [rows, width] = values.shape().dims::<2>();
    if rows == 0 || width == 0 {
        let placeholder = BurnTensor::<B, 2>::zeros([rows.max(1), 1], &values.device());
        return Some(placeholder);
    }
    if width > MAX_GROUP as usize {
        return None;
    }

    // `clamp` would propagate NaN, so it is mapped to 0 explicitly first.
    let quantile = match quantile {
        q if q.is_nan() => 0.0,
        q => q.clamp(0.0, 1.0),
    };

    // Each attempt is only evaluated if every earlier one returned `None`.
    let outcome = try_percentile_thresholds_cubecl_fusion::<B, u32, WgpuRuntime>(values, quantile)
        .or_else(|| {
            try_percentile_thresholds_cubecl_fusion::<B, u8, WgpuRuntime>(values, quantile)
        });

    #[cfg(feature = "cuda")]
    let outcome = outcome
        .or_else(|| {
            try_percentile_thresholds_cubecl_fusion::<B, u32, CudaRuntime>(values, quantile)
        })
        .or_else(|| {
            try_percentile_thresholds_cubecl_fusion::<B, u8, CudaRuntime>(values, quantile)
        })
        .or_else(|| try_percentile_thresholds_cubecl_direct::<B, CudaRuntime>(values, quantile));

    outcome.or_else(|| try_percentile_thresholds_cubecl_direct::<B, WgpuRuntime>(values, quantile))
}

/// Runs the percentile kernel for a backend whose float primitive is a
/// `FusionTensor<FusionCubeRuntime<R, BT>>`.
///
/// Returns `None` if the primitive type does not match, if a type cast fails,
/// or if the resolved tensor's dtype is not `F32`. On success the kernel output
/// is registered with the same fusion client and wrapped back into a
/// `BurnTensor<B, 2>`.
fn try_percentile_thresholds_cubecl_fusion<B, BT, R>(
    values: &BurnTensor<B, 2>,
    quantile: f32,
) -> Option<BurnTensor<B, 2>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    BT: BoolElement + 'static,
    R: CubeRuntime + 'static,
{
    // Check the type up front so a mismatch costs nothing beyond this comparison.
    let is_fusion_primitive =
        matches_type::<B::FloatTensorPrimitive, FusionTensor<FusionCubeRuntime<R, BT>>>();
    if !is_fusion_primitive {
        return None;
    }

    let primitive = values.clone().into_primitive().tensor();
    let fused: FusionTensor<FusionCubeRuntime<R, BT>> = try_cast_primitive::<B, _>(primitive)?;
    let client = fused.client.clone();
    let resolved = client.resolve_tensor_float::<CubeBackend<R, f32, i32, BT>>(fused);
    if resolved.dtype != DType::F32 {
        return None;
    }

    let thresholds = percentile_thresholds_cubecl_runtime::<R>(resolved, quantile);
    // Capture the metadata before the tensor is consumed by the conversion.
    let (shape, dtype) = (thresholds.shape.clone(), thresholds.dtype);
    let handle = thresholds.into();
    let registered = client.register_tensor(handle, shape, StreamId::current(), dtype);

    try_cast_backend::<B, _>(registered)
        .map(|float| BurnTensor::<B, 2>::from_primitive(TensorPrimitive::Float(float)))
}

/// Runs the percentile kernel for a backend whose float primitive is a plain
/// `CubeTensor<R>`.
///
/// Returns `None` if the primitive type does not match, if a type cast fails,
/// or if the tensor's dtype is not `F32`.
fn try_percentile_thresholds_cubecl_direct<B, R>(
    values: &BurnTensor<B, 2>,
    quantile: f32,
) -> Option<BurnTensor<B, 2>>
where
    B: BackendTrait,
    B::FloatTensorPrimitive: 'static,
    R: CubeRuntime + 'static,
{
    // Check the type up front so a mismatch costs nothing beyond this comparison.
    let is_cube_primitive = matches_type::<B::FloatTensorPrimitive, CubeTensor<R>>();
    if !is_cube_primitive {
        return None;
    }

    let primitive = values.clone().into_primitive().tensor();
    let input: CubeTensor<R> = try_cast_primitive::<B, _>(primitive)?;
    if input.dtype != DType::F32 {
        return None;
    }

    let thresholds = percentile_thresholds_cubecl_runtime::<R>(input, quantile);
    try_cast_backend::<B, _>(thresholds)
        .map(|float| BurnTensor::<B, 2>::from_primitive(TensorPrimitive::Float(float)))
}

/// Launches `percentile_thresholds_kernel` on `values` and returns the
/// `[batch, 1]` tensor it writes.
///
/// The input is made contiguous first. The launch uses a `1 x 1 x 1` cube
/// dimension, with the cube count derived from the number of output elements.
/// `quantile` is forwarded unchanged.
fn percentile_thresholds_cubecl_runtime<R: CubeRuntime>(
    values: CubeTensor<R>,
    quantile: f32,
) -> CubeTensor<R> {
    let input = into_contiguous(values);
    let [rows, _] = input.shape.dims::<2>();
    let client = input.client.clone();

    let output = empty_device::<R, f32>(
        client.clone(),
        input.device.clone(),
        Shape::new([rows, 1]),
    );

    let cube_dim = CubeDim::new(1, 1, 1);
    let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);

    percentile_thresholds_kernel::launch::<R>(
        &client,
        cube_count,
        cube_dim,
        input.as_tensor_arg::<f32>(1),
        output.as_tensor_arg::<f32>(1),
        ScalarArg::new(quantile),
    );

    output
}

// Computes one linearly interpolated quantile per row of `values`.
//
// Each invocation handles the row selected by `ABSOLUTE_POS` and writes one
// element of `output`. The row is copied into a scratch buffer, with NaN
// replaced by 0 and +/-inf replaced by +/-3.4028235e38. It is then sorted in
// ascending order and interpolated at position `(group - 1) * quantile`.
//
// `quantile` is assumed to already lie in `[0, 1]`; the host entry point in
// this file clamps it before launching.
//
// NOTE(review): the scratch buffer is a `SharedMemory` allocation indexed from
// 0 regardless of the invocation, so the kernel presumably relies on running
// with a single unit per cube, as the launcher in this file configures with
// `CubeDim::new(1, 1, 1)`.
#[cube(launch)]
fn percentile_thresholds_kernel(values: &Tensor<f32>, output: &mut Tensor<f32>, quantile: f32) {
    // Invocations beyond the output length have nothing to write.
    if ABSOLUTE_POS >= output.len() {
        terminate!();
    }
    let batch = values.shape(0);
    let group = values.shape(1);
    if batch == 0 || group == 0 {
        terminate!();
    }
    // Rows wider than the scratch buffer cannot be sorted here; emit 0 instead.
    if group > MAX_GROUP {
        let out_idx = ABSOLUTE_POS * output.stride(0);
        output[out_idx] = 0.0f32;
        terminate!();
    }

    // Copy the row into scratch storage, replacing non-finite values.
    let mut scratch = SharedMemory::<f32>::new(MAX_GROUP);
    let base = ABSOLUTE_POS * values.stride(0);
    let mut i = 0u32;
    while i < group {
        let idx = base + i * values.stride(1);
        let mut value = values[idx];
        if <f32 as IsNan>::is_nan(value) {
            value = 0.0f32;
        } else if <f32 as IsInf>::is_inf(value) {
            value = if value < 0.0f32 {
                f32::new(-3.4028235e38f32)
            } else {
                f32::new(3.4028235e38f32)
            };
        }
        scratch[i] = value;
        i += 1u32;
    }

    let out_idx = ABSOLUTE_POS * output.stride(0);
    if group == 1u32 {
        output[out_idx] = scratch[0];
        terminate!();
    }
    if group == 2u32 {
        let a = scratch[0];
        let b = scratch[1];
        let lo = if a < b { a } else { b };
        let hi = if a < b { b } else { a };
        let span = hi - lo;
        let mut result = lo + span * quantile;
        // `hi - lo` can overflow to +inf (e.g. a row holding -inf and +inf after
        // sanitising), which would yield NaN or +inf. Fall back to a form whose
        // intermediates stay finite for a quantile in [0, 1].
        if <f32 as IsInf>::is_inf(span) {
            result = lo - lo * quantile + hi * quantile;
        }
        output[out_idx] = result;
        terminate!();
    }

    // Insertion-style sort: sink element `i` through the already sorted prefix.
    let mut i = 1u32;
    while i < group {
        let mut j = i;
        while j > 0u32 {
            let prev = scratch[j - 1u32];
            let curr = scratch[j];
            if curr < prev {
                scratch[j - 1u32] = curr;
                scratch[j] = prev;
            }
            j -= 1u32;
        }
        i += 1u32;
    }

    // Interpolate between the two sorted neighbours around the fractional rank.
    let pos = (group - 1u32) as f32 * quantile;
    let lower = pos as u32;
    let upper = if lower + 1u32 < group {
        lower + 1u32
    } else {
        lower
    };
    let weight = pos - lower as f32;
    let lower_val = scratch[lower];
    let upper_val = scratch[upper];
    let span = upper_val - lower_val;
    let mut result = lower_val + span * weight;
    // Same overflow guard as the two-element case above.
    if <f32 as IsInf>::is_inf(span) {
        result = lower_val - lower_val * weight + upper_val * weight;
    }
    output[out_idx] = result;
}

/// Returns `true` when `A` and `B` are the same concrete type, judged by
/// comparing their `TypeId`s.
fn matches_type<A: 'static, B: 'static>() -> bool {
    let left = TypeId::of::<A>();
    let right = TypeId::of::<B>();
    left == right
}

/// Converts a backend float primitive into the concrete type `T`.
///
/// The value is boxed as `dyn Any` and downcast. The result is `Some` only
/// when `B::FloatTensorPrimitive` is exactly `T`; otherwise it is `None`.
fn try_cast_primitive<B: BackendTrait, T: 'static>(value: B::FloatTensorPrimitive) -> Option<T>
where
    B::FloatTensorPrimitive: 'static,
{
    let erased: Box<dyn Any> = Box::new(value);
    match erased.downcast::<T>() {
        Ok(concrete) => Some(*concrete),
        Err(_) => None,
    }
}

/// Converts a concrete value of type `T` into the float primitive of backend `B`.
///
/// The value is boxed as `dyn Any` and downcast. The result is `Some` only
/// when `T` is exactly `B::FloatTensorPrimitive`; otherwise it is `None`.
fn try_cast_backend<B: BackendTrait, T: 'static>(value: T) -> Option<B::FloatTensorPrimitive>
where
    B::FloatTensorPrimitive: 'static,
{
    let erased: Box<dyn Any> = Box::new(value);
    match erased.downcast::<B::FloatTensorPrimitive>() {
        Ok(primitive) => Some(*primitive),
        Err(_) => None,
    }
}