// Per-entity local transform, read by `main` from `local_transforms`.
// NOTE(review): field order fixes the storage-buffer layout. WGSL gives
// vec3<f32> 16-byte alignment but a 12-byte size, so `parent_index` occupies
// the last 4 bytes of `translation`'s slot. Presumably a host-side struct
// mirrors this layout; keep the two in sync.
struct LocalTransform {
// Translation; becomes the fourth column of the entity's local matrix.
translation: vec3<f32>,
// Index of the parent entity in `model_matrices`; a negative value means the
// entity has no parent and its local matrix is used as-is.
parent_index: i32,
// Rotation quaternion as consumed by quat_to_mat3: xyz = vector part, w = scalar part.
rotation: vec4<f32>,
// Per-axis scale applied to the rotated basis columns.
scale: vec3<f32>,
// Level selector: `main` only processes this entity when depth equals
// `uniforms.current_depth`. Presumably the entity's depth in the hierarchy;
// confirm on the host side.
depth: u32,
};
// Per-entity output written by `main`; `model` is also read back as the parent
// matrix for child entities.
// NOTE(review): in storage layout each mat3x3<f32> column is padded to 16 bytes.
// Presumably the host-side mirror accounts for this; confirm.
struct ModelMatrix {
// World matrix: the parent's `model` times this entity's local TRS matrix, or
// just the local matrix when the entity has no parent.
model: mat4x4<f32>,
// Matrix for transforming normals, derived from `model` by compute_normal_matrix.
normal_matrix: mat3x3<f32>,
};
// Per-dispatch parameters for `main`.
struct Uniforms {
// Number of valid entities; invocations with id.x >= entity_count exit immediately.
entity_count: u32,
// Only entities whose `depth` equals this value are processed by the dispatch.
current_depth: u32,
// Unused padding. Presumably it rounds the struct up to 16 bytes to match the
// host-side uniform buffer; confirm.
_pad: vec2<u32>,
};
// Input: one LocalTransform per entity, indexed by global_invocation_id.x.
@group(0) @binding(0) var<storage, read> local_transforms: array<LocalTransform>;
// Output: one ModelMatrix per entity with the same indexing; parents' entries
// are also read back when composing child matrices.
@group(0) @binding(1) var<storage, read_write> model_matrices: array<ModelMatrix>;
// Dispatch parameters: entity count and the depth currently being processed.
@group(0) @binding(2) var<uniform> uniforms: Uniforms;
// Converts rotation quaternion `q` (xyz = vector part, w = scalar part) into a
// column-major 3x3 rotation matrix.
//
// `q` does not have to be unit length. The doubled products are scaled by
// 2 / |q|^2 instead of a constant 2, which yields the rotation of normalize(q).
// The constant-2 form assumes |q| == 1, so quaternions that drifted off unit
// length (interpolation, accumulated deltas, quantized uploads) would produce
// matrices with spurious scale and shear. For an exactly unit `q`, 2 / |q|^2 is
// exactly 2 and the result is bit-identical to the constant-2 form.
//
// A (near-)zero quaternion carries no rotation and maps to the identity. This
// matches the limit of the constant-2 form.
fn quat_to_mat3(q: vec4<f32>) -> mat3x3<f32> {
    let len_sq = dot(q, q);
    // Below this threshold 2 / len_sq can overflow f32, and 0 * inf would give NaN.
    if len_sq < 1e-30 {
        return mat3x3<f32>(
            vec3<f32>(1.0, 0.0, 0.0),
            vec3<f32>(0.0, 1.0, 0.0),
            vec3<f32>(0.0, 0.0, 1.0)
        );
    }
    let s = 2.0 / len_sq;
    let x2 = q.x * s;
    let y2 = q.y * s;
    let z2 = q.z * s;
    let xx = q.x * x2;
    let xy = q.x * y2;
    let xz = q.x * z2;
    let yy = q.y * y2;
    let yz = q.y * z2;
    let zz = q.z * z2;
    let wx = q.w * x2;
    let wy = q.w * y2;
    let wz = q.w * z2;
    return mat3x3<f32>(
        vec3<f32>(1.0 - yy - zz, xy + wz, xz - wy),
        vec3<f32>(xy - wz, 1.0 - xx - zz, yz + wx),
        vec3<f32>(xz + wy, yz - wx, 1.0 - xx - yy)
    );
}
// Returns the matrix used to transform normals for an object whose world matrix is `m`.
//
// Normals must be transformed by the inverse-transpose of the upper-left 3x3
// block A = [a | b | c] of `m`. Normalizing the columns of A (which only yields
// the rotation part) is wrong as soon as the scale is non-uniform: surfaces get
// squashed but their normals do not tilt accordingly.
//
// transpose(inverse(A)) = cof(A) / det(A), where the cofactor matrix cof(A) has
// columns cross(b, c), cross(c, a), cross(a, b), and det(A) = dot(a, cross(b, c)).
// This is exact for any invertible A. That includes non-uniform scale and the
// shear that appears when a rotated child sits under a non-uniformly scaled parent.
//
// The result is additionally multiplied by the positive factor |det(A)|^(1/3),
// giving a determinant of +-1. Any positive multiple is an equally valid normal
// matrix, and this choice means that for a rotation R with uniform scale s the
// result equals sign(s) * R. Normals transformed by it are in general not unit
// length (non-uniform scale), so the consumer must renormalize them.
//
// A singular A (e.g. a zero scale axis) has no inverse. In that case the raw
// cofactor matrix is returned, which keeps any surviving normal direction.
fn compute_normal_matrix(m: mat4x4<f32>) -> mat3x3<f32> {
    let a = m[0].xyz;
    let b = m[1].xyz;
    let c = m[2].xyz;
    let cof = mat3x3<f32>(cross(b, c), cross(c, a), cross(a, b));
    let det = dot(a, cof[0]);
    if abs(det) < 1e-30 {
        return cof;
    }
    // cof / det * |det|^(1/3) == cof * sign(det) * |det|^(-2/3)
    return cof * (sign(det) * pow(abs(det), -2.0 / 3.0));
}
// Compute entry point. Writes the world and normal matrices of every entity
// whose `depth` equals `uniforms.current_depth`, with one invocation per
// entity indexed by global_invocation_id.x.
//
// A child's world matrix is its parent's `model`, read back from
// `model_matrices`, times the child's local TRS matrix. That read is only
// well-defined if the parent was written by an earlier dispatch. Presumably the
// host dispatches this shader once per depth, in increasing order, and every
// parent has a smaller depth than its children; confirm on the host side.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id: vec3<u32>) {
    let entity = id.x;

    // Guard clauses: first skip trailing invocations of the last workgroup,
    // then skip entities that belong to a different depth than the one being
    // processed.
    if entity >= uniforms.entity_count {
        return;
    }
    let xf = local_transforms[entity];
    if xf.depth != uniforms.current_depth {
        return;
    }

    // Local TRS matrix: rotation columns scaled per axis, translation in column 3.
    let basis = quat_to_mat3(xf.rotation);
    let local_matrix = mat4x4<f32>(
        vec4<f32>(basis[0] * xf.scale.x, 0.0),
        vec4<f32>(basis[1] * xf.scale.y, 0.0),
        vec4<f32>(basis[2] * xf.scale.z, 0.0),
        vec4<f32>(xf.translation, 1.0),
    );

    // Entities with a negative parent_index keep their local matrix unchanged.
    var world = local_matrix;
    if xf.parent_index >= 0 {
        world = model_matrices[xf.parent_index].model * local_matrix;
    }

    model_matrices[entity] = ModelMatrix(world, compute_normal_matrix(world));
}