vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
// Message to hash, supplied as u32 words; read (never written) by hash_md5.
@group(0) @binding(0) var<storage, read> input_words: array<u32>;
// Digest destination: hash_md5 writes words 0..3, each passed through md5_be_word.
@group(0) @binding(1) var<storage, read_write> output_words: array<u32>;

// Per-step left-rotation amounts for MD5's 64 compression steps (RFC 1321).
// One row per 16-step round; indexed by the step number and passed to md5_rotl.
const MD5_S: array<u32, 64> = array<u32, 64>(
  7u, 12u, 17u, 22u, 7u, 12u, 17u, 22u, 7u, 12u, 17u, 22u, 7u, 12u, 17u, 22u,
  5u, 9u, 14u, 20u, 5u, 9u, 14u, 20u, 5u, 9u, 14u, 20u, 5u, 9u, 14u, 20u,
  4u, 11u, 16u, 23u, 4u, 11u, 16u, 23u, 4u, 11u, 16u, 23u, 4u, 11u, 16u, 23u,
  6u, 10u, 15u, 21u, 6u, 10u, 15u, 21u, 6u, 10u, 15u, 21u, 6u, 10u, 15u, 21u);

// Per-step additive constants for MD5 (the T[1..64] table of RFC 1321,
// T[i] = floor(abs(sin(i)) * 2^32)); indexed by the 0-based step number.
// Eight values per line, so each pair of lines covers one 16-step round.
const MD5_K: array<u32, 64> = array<u32, 64>(
  0xd76aa478u, 0xe8c7b756u, 0x242070dbu, 0xc1bdceeeu, 0xf57c0fafu, 0x4787c62au, 0xa8304613u, 0xfd469501u,
  0x698098d8u, 0x8b44f7afu, 0xffff5bb1u, 0x895cd7beu, 0x6b901122u, 0xfd987193u, 0xa679438eu, 0x49b40821u,
  0xf61e2562u, 0xc040b340u, 0x265e5a51u, 0xe9b6c7aau, 0xd62f105du, 0x02441453u, 0xd8a1e681u, 0xe7d3fbc8u,
  0x21e1cde6u, 0xc33707d6u, 0xf4d50d87u, 0x455a14edu, 0xa9e3e905u, 0xfcefa3f8u, 0x676f02d9u, 0x8d2a4c8au,
  0xfffa3942u, 0x8771f681u, 0x6d9d6122u, 0xfde5380cu, 0xa4beea44u, 0x4bdecfa9u, 0xf6bb4b60u, 0xbebfbc70u,
  0x289b7ec6u, 0xeaa127fau, 0xd4ef3085u, 0x04881d05u, 0xd9d4d039u, 0xe6db99e5u, 0x1fa27cf8u, 0xc4ac5665u,
  0xf4292244u, 0x432aff97u, 0xab9423a7u, 0xfc93a039u, 0x655b59c3u, 0x8f0ccc92u, 0xffeff47du, 0x85845dd1u,
  0x6fa87e4fu, 0xfe2ce6e0u, 0xa3014314u, 0x4e0811a1u, 0xf7537e82u, 0xbd3af235u, 0x2ad7d2bbu, 0xeb86d391u);

// Rotates `x` left by `n` bits (n expected in 0..31).
// The complementary right shift is masked to 0..31 so that n == 0 shifts by 0
// explicitly instead of relying on WGSL's implicit modulo for shift >= 32.
fn md5_rotl(x: u32, n: u32) -> u32 {
  let right_amount = (32u - n) & 31u;
  let left_part = x << n;
  let right_part = x >> right_amount;
  return left_part | right_part;
}

// MD5 auxiliary (boolean mixing) function for step `i`:
//   steps  0..15 -> F = (b & c) | (~b & d)
//   steps 16..31 -> G = (b & d) | (c & ~d)
//   steps 32..47 -> H = b ^ c ^ d
//   steps 48..   -> I = c ^ (b | ~d)
fn md5_step_f(i: u32, b: u32, c: u32, d: u32) -> u32 {
  var mixed: u32;
  // i >> 4 is the 16-step round number; anything past round 2 uses I.
  switch (i >> 4u) {
    case 0u: { mixed = (b & c) | (d & ~b); }
    case 1u: { mixed = (b & d) | (c & ~d); }
    case 2u: { mixed = b ^ c ^ d; }
    default: { mixed = c ^ (b | ~d); }
  }
  return mixed;
}

// Index (0..15) of the message word consumed by MD5 step `i`:
//   steps  0..15 -> i
//   steps 16..31 -> (5i + 1) mod 16
//   steps 32..47 -> (3i + 5) mod 16
//   steps 48..   -> 7i mod 16
fn md5_step_g(i: u32) -> u32 {
  var word: u32;
  switch (i >> 4u) {
    case 0u: { word = i; }
    case 1u: { word = 5u * i + 1u; }
    case 2u: { word = 3u * i + 5u; }
    default: { word = 7u * i; }
  }
  // For i < 16 the mask is a no-op, so every round can share it.
  return word & 15u;
}

// Reverses the byte order of `x` (byte 0 <-> byte 3, byte 1 <-> byte 2).
fn md5_be_word(x: u32) -> u32 {
  let byte0 = x & 0xffu;
  let byte1 = (x >> 8u) & 0xffu;
  let byte2 = (x >> 16u) & 0xffu;
  let byte3 = x >> 24u;
  return (byte0 << 24u) | (byte1 << 16u) | (byte2 << 8u) | byte3;
}

// Returns word `j` of the MD5-padded message formed from `input_words`.
//   j <  msg_words             -> the message word itself
//   j == msg_words             -> 0x80 terminator (the message is word-aligned,
//                                 so the 0x80 byte always lands in byte 0 of this word)
//   j == total_words - 2 / - 1 -> low / high half of the 64-bit bit length
//   otherwise                  -> zero padding
// total_words >= msg_words + 3, so the terminator never collides with the length words.
fn md5_padded_word(j: u32, msg_words: u32, total_words: u32) -> u32 {
  if (j < msg_words) { return input_words[j]; }
  if (j == msg_words) { return 0x80u; }
  // Bit length = msg_words * 32, split into 32-bit halves (low word first).
  if (j == total_words - 2u) { return msg_words << 5u; }
  if (j == total_words - 1u) { return msg_words >> 27u; }
  return 0u;
}

// Computes the MD5 digest of the entire `input_words` buffer (4 bytes per word,
// each u32 used directly as an MD5 message word) and writes it to
// output_words[0..3]. Each state word is passed through md5_be_word, so every
// output u32 printed as hex spells the matching 8 hex digits of the digest.
//
// Only the invocation with id.x == 0 does any work; the whole hash runs serially
// in that one invocation, processing as many 64-byte blocks as the input needs.
@compute @workgroup_size(1, 1, 1)
fn hash_md5(@builtin(global_invocation_id) id: vec3<u32>) {
  if (id.x != 0u) { return; }

  let msg_words = arrayLength(&input_words);
  // The message must be followed by one terminator word and two length words;
  // round (msg_words + 3) up to a whole number of 16-word blocks.
  let block_count = (msg_words + 18u) / 16u;
  let total_words = block_count * 16u;

  var h0 = 0x67452301u;
  var h1 = 0xefcdab89u;
  var h2 = 0x98badcfeu;
  var h3 = 0x10325476u;

  var block: array<u32, 16>;
  for (var blk = 0u; blk < block_count; blk = blk + 1u) {
    let base = blk * 16u;
    for (var j = 0u; j < 16u; j = j + 1u) {
      block[j] = md5_padded_word(base + j, msg_words, total_words);
    }

    var a = h0;
    var b = h1;
    var c = h2;
    var d = h3;
    for (var i = 0u; i < 64u; i = i + 1u) {
      let f = md5_step_f(i, b, c, d);
      let g = md5_step_g(i);
      let next = b + md5_rotl(a + f + MD5_K[i] + block[g], MD5_S[i]);
      a = d;
      d = c;
      c = b;
      b = next;
    }
    // Feed-forward: add this block's result into the running chaining state.
    h0 = h0 + a;
    h1 = h1 + b;
    h2 = h2 + c;
    h3 = h3 + d;
  }

  output_words[0u] = md5_be_word(h0);
  output_words[1u] = md5_be_word(h1);
  output_words[2u] = md5_be_word(h2);
  output_words[3u] = md5_be_word(h3);
}