brainharmony 0.1.0

Brain-Harmony multimodal brain foundation model — inference in Rust with the Burn deep-learning framework
Documentation
/// Non-causal flash attention via cubecl/cubek-attention.
///
/// Burn 0.20.1 hardcodes `causal: true` in its flash attention dispatch.
/// Brain-Harmony uses **bidirectional** (non-causal) attention (ViT, not GPT).
///
/// This module calls `cubek::attention::launch::launch_ref` directly with `causal: false`.
#[cfg(any(feature = "wgpu", feature = "wgpu-f16"))]
pub mod gpu {
    use burn::prelude::*;
    use burn::tensor::TensorPrimitive;
    use burn_cubecl::tensor::CubeTensor;
    use burn_cubecl::{CubeRuntime, ops::numeric::empty_device_dtype};
    use burn_backend::Shape;
    use cubek::attention::{
        definition::{AccumulatorPrecision, AttentionGlobalTypes, AttentionOptions},
        launch::Strategy,
    };

    /// Asserts that the `query` / `key` / `value` dims describe a valid attention triple.
    ///
    /// Requires rank 4 everywhere, identical batch and head dims, matching query/key head
    /// dims, and matching key/value sequence lengths (the value head dim may differ).
    fn validate_attention_shapes(query: &[usize], key: &[usize], value: &[usize]) {
        assert!(
            query.len() == 4 && key.len() == 4 && value.len() == 4,
            "flash attention expects rank-4 [B, H, N, D] inputs, \
             got q={query:?} k={key:?} v={value:?}"
        );
        assert!(
            query[..2] == key[..2] && query[..2] == value[..2],
            "flash attention batch/head dims must match, got q={query:?} k={key:?} v={value:?}"
        );
        assert_eq!(
            query[3], key[3],
            "flash attention query/key head dims must match, got q={query:?} k={key:?}"
        );
        assert_eq!(
            key[2], value[2],
            "flash attention key/value sequence lengths must match, got k={key:?} v={value:?}"
        );
    }

    /// Run non-causal (bidirectional) flash attention on the GPU via cubek-attention.
    ///
    /// * `query`: `[B, H, N_q, D]`
    /// * `key`:   `[B, H, N_kv, D]`
    /// * `value`: `[B, H, N_kv, D_v]`
    ///
    /// Q should NOT be pre-scaled — this kernel handles scaling via the tiled algorithm.
    /// No attention mask is passed to the kernel. Accumulation is requested in strict
    /// `f32`, and the output uses the query's dtype.
    ///
    /// Returns a freshly allocated `CubeTensor` of shape `[B, H, N_q, D_v]`.
    ///
    /// # Panics
    ///
    /// Panics if any input is not rank 4, if batch/head dims differ between inputs, if the
    /// query and key head dims differ, if the key and value sequence lengths differ, or if
    /// the kernel launch fails.
    pub fn flash_attention_noncausal<R: CubeRuntime>(
        query: CubeTensor<R>,
        key: CubeTensor<R>,
        value: CubeTensor<R>,
    ) -> CubeTensor<R> {
        // Fail fast with a readable message. Otherwise a shape problem would surface either
        // as an index panic below or as a kernel launch with inconsistent extents.
        validate_attention_shapes(&query.shape.dims, &key.shape.dims, &value.shape.dims);

        let client = &query.client;
        let device = &query.device;
        let out_dtype = query.dtype;

        let num_batches = query.shape.dims[0];
        let num_heads = query.shape.dims[1];
        let seq_q = query.shape.dims[2];
        // The output's last dim follows the value tensor, not the query/key head dim.
        let val_dim = value.shape.dims[3];
        let out_shape = Shape::new([num_batches, num_heads, seq_q, val_dim]);

        let out = empty_device_dtype::<R>(client.clone(), device.clone(), out_shape, out_dtype);

        let dtypes = AttentionGlobalTypes {
            query: query.dtype.into(),
            key: key.dtype.into(),
            value: value.dtype.into(),
            // No mask tensor is passed below; a mask dtype is still required by the API.
            mask: cubecl::ir::StorageType::Scalar(cubecl::ir::ElemType::UInt(
                cubecl::ir::UIntKind::U8,
            )),
            out: out.dtype.into(),
        };

        // NOTE(review): strides are forwarded through `as_handle_ref`; confirm whether the
        // Unit strategy handles non-contiguous (e.g. swap_dims'd) inputs correctly.
        cubek::attention::launch::launch_ref::<R>(
            Strategy::Unit(cubek::attention::launch::BlueprintStrategy::Inferred(())),
            client,
            &query.as_handle_ref(),
            &key.as_handle_ref(),
            &value.as_handle_ref(),
            &None,
            &out.as_handle_ref(),
            &dtypes,
            AttentionOptions {
                // The whole point of this module: bidirectional attention.
                causal: false,
                accumulator_precision: AccumulatorPrecision::Strict(
                    cubecl::ir::StorageType::Scalar(cubecl::ir::ElemType::Float(
                        cubecl::ir::FloatKind::F32,
                    )),
                ),
            },
        )
        .expect("non-causal flash attention kernel launch failed");

        out
    }

    /// High-level wrapper around [`flash_attention_noncausal`] for Burn tensors.
    ///
    /// Unwraps each float tensor into its `CubeTensor` primitive, runs the non-causal
    /// kernel, and wraps the result back into a float `Tensor`.
    ///
    /// q: `[B, H, N_q, D]`, k: `[B, H, N_kv, D]`, v: `[B, H, N_kv, D_v]`.
    /// Q should NOT be pre-scaled.
    ///
    /// Returns `[B, H, N_q, D_v]`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`flash_attention_noncausal`].
    pub fn flash_attention_tensor<B, R>(
        q: Tensor<B, 4>,
        k: Tensor<B, 4>,
        v: Tensor<B, 4>,
    ) -> Tensor<B, 4>
    where
        B: Backend<FloatTensorPrimitive = CubeTensor<R>>,
        R: CubeRuntime,
    {
        let q_prim = q.into_primitive().tensor();
        let k_prim = k.into_primitive().tensor();
        let v_prim = v.into_primitive().tensor();

        let out_prim = flash_attention_noncausal(q_prim, k_prim, v_prim);
        Tensor::from_primitive(TensorPrimitive::Float(out_prim))
    }
}