burn_dragon_kernel 0.5.0

Fused GPU kernel crate for burn_dragon execution paths
Documentation
// Generic local-grid rho attention kernel for 2D token lattices.
//
// Each invocation produces one context element by gathering, from grid
// neighbours of the destination token, the dot product (over the latent axis)
// of the neighbour's rho_state with the destination token's query. It then
// writes that token's updated rho state into rho_next.
//
// All buffers are declared read_write storage. Within this shader only
// `context` and `rho_next` are actually written; the rest are only read.
// The layouts noted below are the ones implied by the idx_* helpers, where the
// "time" extent passed by main is patch_tokens.

// Query activations, addressed by idx_query: [batch, heads, patch_tokens, latent].
@group(0) @binding(0)
var<storage, read_write> query: array<f32>;

// Value activations, addressed by idx_value: [batch, value_heads, patch_tokens, embd].
@group(0) @binding(1)
var<storage, read_write> value: array<f32>;

// Incoming state, addressed by idx_rho: [batch, heads, patch_tokens, latent, embd].
@group(0) @binding(2)
var<storage, read_write> rho_state: array<f32>;

// Decay factor read as decay[h] for query head h (one entry per head, presumably).
@group(0) @binding(3)
var<storage, read_write> decay: array<f32>;

// Output context, addressed by idx_context: [batch, heads, patch_tokens, embd].
@group(0) @binding(4)
var<storage, read_write> context: array<f32>;

// Output state, same layout as rho_state (addressed by idx_rho).
@group(0) @binding(5)
var<storage, read_write> rho_next: array<f32>;

// Scalar launch parameters, stored as f32 and decoded with to_u32:
//   [0] batch, [1] heads, [2] value_heads, [3] patch_tokens, [4] latent,
//   [5] embd, [6] grid_h, [7] grid_w, [8] radius,
//   [9] local_diagonals flag (enabled iff == 1), [10] local_self flag (enabled iff == 1).
@group(0) @binding(6)
var<storage, read_write> params: array<f32>;

// Decode an integer launch parameter that the host stored as f32.
// Adding one half before the truncating u32 conversion rounds non-negative
// inputs to the nearest integer, so a value like 2.9999 still decodes as 3.
fn to_u32(v: f32) -> u32 {
  let nudged = v + 0.5;
  return u32(nudged);
}

// Flat offset of element (b, h, t, l) in a row-major [_, heads, time, latent] buffer.
fn idx_query(b: u32, h: u32, t: u32, l: u32, heads: u32, time: u32, latent: u32) -> u32 {
  let head_row = b * heads + h;
  let token_row = head_row * time + t;
  return token_row * latent + l;
}

// Flat offset of element (b, h, t, e) in a row-major [_, value_heads, time, embd] buffer.
fn idx_value(
  b: u32,
  h: u32,
  t: u32,
  e: u32,
  value_heads: u32,
  time: u32,
  embd: u32,
) -> u32 {
  let head_row = b * value_heads + h;
  let token_row = head_row * time + t;
  return token_row * embd + e;
}

// Flat offset of element (b, h, t, l, e) in a row-major
// [_, heads, time, latent, embd] buffer.
fn idx_rho(
  b: u32,
  h: u32,
  t: u32,
  l: u32,
  e: u32,
  heads: u32,
  time: u32,
  latent: u32,
  embd: u32,
) -> u32 {
  let head_row = b * heads + h;
  let token_row = head_row * time + t;
  let latent_row = token_row * latent + l;
  return latent_row * embd + e;
}

// Flat offset of element (b, h, t, e) in a row-major [_, heads, time, embd] buffer.
fn idx_context(b: u32, h: u32, t: u32, e: u32, heads: u32, time: u32, embd: u32) -> u32 {
  let head_row = b * heads + h;
  let token_row = head_row * time + t;
  return token_row * embd + e;
}

// Entry point: one invocation per (embedding channel e = gid.x,
// destination token = gid.y, flattened batch*head = gid.z).
//
// Phase 1 writes context[b, h, dst, e] = sum over allowed grid neighbours `src`
// within a (2*radius+1)^2 window, and over l in [0, latent), of
// rho_state[b, h, src, l, e] * query[b, h, dst, l].
// Phase 2 writes, for every l,
// rho_next[b, h, dst, l, e] = rho_state[b, h, dst, l, e] * decay[h]
//                             + query[b, h, dst, l] * value[b, h_value, dst, e].
@compute @workgroup_size(8, 8, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  // Decode scalar launch parameters (see the params binding for the slot map).
  let batch = to_u32(params[0]);
  let heads = to_u32(params[1]);
  let value_heads = to_u32(params[2]);
  let patch_tokens = to_u32(params[3]);
  let latent = to_u32(params[4]);
  let embd = to_u32(params[5]);
  let grid_h = to_u32(params[6]);
  let grid_w = to_u32(params[7]);
  let radius = i32(to_u32(params[8]));
  let allow_diagonals = to_u32(params[9]) == 1u;
  let allow_self = to_u32(params[10]) == 1u;

  let e = gid.x;
  let dst_token = gid.y;
  let bh = gid.z;

  // Invocations past the logical problem size do nothing.
  let in_range = bh < batch * heads && dst_token < patch_tokens && e < embd;
  if !in_range {
    return;
  }

  let b = bh / heads;
  let h = bh % heads;

  // With a single value head it is shared by every query head; otherwise the
  // query head index is reused as-is.
  // NOTE(review): value_heads values other than 1 or heads are not handled here.
  var h_value = h;
  if value_heads == 1u {
    h_value = 0u;
  }

  let decay_h = decay[h];
  // Destination token's position on the grid (row-major, grid_w columns).
  let dst_y = i32(dst_token / grid_w);
  let dst_x = i32(dst_token % grid_w);

  // Phase 1: neighbourhood gather. Iteration order (dy, then dx, then l) is
  // kept fixed so the f32 accumulation order is deterministic.
  var ctx_sum = 0.0;
  for (var dy = -radius; dy <= radius; dy += 1) {
    for (var dx = -radius; dx <= radius; dx += 1) {
      let is_center = dy == 0 && dx == 0;
      let is_diagonal = dy != 0 && dx != 0;
      // The centre is skipped unless local_self is set; offsets that move on
      // both axes are skipped unless local_diagonals is set.
      let excluded = (is_center && !allow_self) || (is_diagonal && !allow_diagonals);
      if excluded {
        continue;
      }

      let src_y = dst_y + dy;
      let src_x = dst_x + dx;
      if src_y < 0 || src_y >= i32(grid_h) || src_x < 0 || src_x >= i32(grid_w) {
        continue;
      }

      let src_token = u32(src_y) * grid_w + u32(src_x);
      // Guards grids whose cell count exceeds patch_tokens.
      if src_token >= patch_tokens {
        continue;
      }

      for (var l = 0u; l < latent; l += 1u) {
        let q_idx = idx_query(b, h, dst_token, l, heads, patch_tokens, latent);
        let r_idx = idx_rho(b, h, src_token, l, e, heads, patch_tokens, latent, embd);
        ctx_sum += rho_state[r_idx] * query[q_idx];
      }
    }
  }

  let ctx_idx = idx_context(b, h, dst_token, e, heads, patch_tokens, embd);
  context[ctx_idx] = ctx_sum;

  // Phase 2: decayed state update for the destination token only. Each
  // invocation writes a distinct (dst_token, l, e) slice of rho_next.
  let v_idx = idx_value(b, h_value, dst_token, e, value_heads, patch_tokens, embd);
  let v_tok = value[v_idx];
  for (var l = 0u; l < latent; l += 1u) {
    let q_idx = idx_query(b, h, dst_token, l, heads, patch_tokens, latent);
    let r_idx = idx_rho(b, h, dst_token, l, e, heads, patch_tokens, latent, embd);
    rho_next[r_idx] = rho_state[r_idx] * decay_h + query[q_idx] * v_tok;
  }
}