burn_dragon_kernel 0.21.0-pre.13

Fused GPU kernel crate for burn_dragon execution paths
Documentation
const WORKGROUP_SIZE_X: u32 = 64u;
const MAX_FUSED_TIME: u32 = 1024u;

@group(0) @binding(0)
var<storage, read_write> query: array<f32>;

@group(0) @binding(1)
var<storage, read_write> value: array<f32>;

@group(0) @binding(2)
var<storage, read_write> context: array<f32>;

@group(0) @binding(3)
var<storage, read_write> decay: array<f32>;

@group(0) @binding(4)
var<storage, read_write> params: array<f32>;

var<workgroup> row_scores: array<f32, MAX_FUSED_TIME>;

fn to_u32(v: f32) -> u32 {
  return u32(v + 0.5);
}

fn idx_query(b: u32, h: u32, t: u32, l: u32, heads: u32, time: u32, latent: u32) -> u32 {
  return (((b * heads + h) * time + t) * latent + l);
}

fn idx_value(
  b: u32,
  vh: u32,
  t: u32,
  e: u32,
  value_heads: u32,
  time: u32,
  value_dim: u32,
) -> u32 {
  return (((b * value_heads + vh) * time + t) * value_dim + e);
}

fn idx_context(
  b: u32,
  h: u32,
  row: u32,
  e: u32,
  heads: u32,
  time: u32,
  value_dim: u32,
) -> u32 {
  return (((b * heads + h) * time + row) * value_dim + e);
}

@compute @workgroup_size(WORKGROUP_SIZE_X, 1, 1)
fn main(
  @builtin(global_invocation_id) gid: vec3<u32>,
  @builtin(local_invocation_id) lid: vec3<u32>,
  @builtin(workgroup_id) wid: vec3<u32>,
) {
  let batch = to_u32(params[0]);
  let heads = to_u32(params[1]);
  let value_heads = to_u32(params[2]);
  let time = to_u32(params[3]);
  let latent = to_u32(params[4]);
  let value_dim = to_u32(params[5]);

  let h = wid.y;
  let batch_row = wid.z;
  let b = batch_row / time;
  let row = batch_row % time;
  let e = gid.x;
  let lane = lid.x;

  if b >= batch || h >= heads || row >= time || time > MAX_FUSED_TIME {
    return;
  }

  let decay_value = decay[h];
  var col = lane;
  while col < row {
    var dot = 0.0;
    var l = 0u;
    while l < latent {
      let q_row = query[idx_query(b, h, row, l, heads, time, latent)];
      let q_col = query[idx_query(b, h, col, l, heads, time, latent)];
      dot += q_row * q_col;
      l += 1u;
    }
    row_scores[col] = dot * pow(decay_value, f32(row - col));
    col += WORKGROUP_SIZE_X;
  }
  workgroupBarrier();

  if e >= value_dim {
    return;
  }

  let value_head = select(h, 0u, value_heads == 1u);
  var acc = 0.0;
  col = 0u;
  while col < row {
    acc += row_scores[col] * value[idx_value(b, value_head, col, e, value_heads, time, value_dim)];
    col += 1u;
  }

  context[idx_context(b, h, row, e, heads, time, value_dim)] = acc;
}