burn_dragon_core 0.4.0

Burn Dragon core model and utilities.
Documentation
use burn::tensor::Tensor;
use burn::tensor::activation;
use burn::tensor::backend::Backend;

use super::block_sparse::BlockPattern1d;

/// Fused projection pipeline: matmul, optional bias, threshold shift, ReLU, and an
/// optional block-sparsity mask, applied in that order.
///
/// 1. Projects `input` with `input.matmul(weight)`.
/// 2. If `bias` is `Some`, its dims `[d0, d1, d2]` are reshaped to `[1, d0, 1, d2]`, and the
///    result is added (broadcast) to the projection. A reshape must preserve the element
///    count, so this only succeeds when `d1 == 1`. Presumably bias is laid out as
///    `[heads, 1, latent]`; TODO confirm against callers.
/// 3. If `threshold` is non-zero, it is subtracted from every element.
/// 4. Applies ReLU.
/// 5. If `layout.is_sparse()`, multiplies the activation elementwise by
///    `layout.mask::<B>(latent, &device)`. Here `latent` is the last dimension of `weight`,
///    and `device` is the device of `input`.
///
/// NOTE(review): shape compatibility between `input`/`weight`, the reshaped bias, and the
/// mask is not checked here; mismatches are left to the tensor backend.
pub fn fused_forward<B: Backend>(
    input: Tensor<B, 4>,
    weight: Tensor<B, 4>,
    bias: Option<Tensor<B, 3>>,
    threshold: f32,
    layout: &BlockPattern1d,
) -> Tensor<B, 4> {
    // Both values are read up front because `matmul` consumes `input` and `weight`.
    let [_, _, _, latent_dim] = weight.shape().dims::<4>();
    let device = input.device();

    let projected = input.matmul(weight);

    let biased = match bias {
        Some(bias) => {
            let [heads, _, width] = bias.shape().dims::<3>();
            projected + bias.reshape([1, heads, 1, width])
        }
        None => projected,
    };

    // Subtracting zero would be a no-op, so the extra tensor op is skipped in that case.
    let shifted = if threshold == 0.0 {
        biased
    } else {
        biased.sub_scalar(threshold)
    };

    let activated = activation::relu(shifted);
    if !layout.is_sparse() {
        return activated;
    }

    let mask = layout.mask::<B>(latent_dim, &device);
    activated * mask
}