use crate::ecs::animation::components::{
AnimationChannel, AnimationInterpolation, AnimationProperty, AnimationSamplerOutput,
};
use crate::ecs::skin::systems::SkinningCache;
use crate::ecs::world::{ANIMATION_PLAYER, Entity, World};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use wgpu::util::DeviceExt;
// Property codes stored in `GpuAnimChannel::property`; `gpu_property` maps
// `AnimationProperty` variants onto them.
// NOTE(review): these presumably mirror constants in `animation_compute.wgsl` — keep in sync.
const PROPERTY_TRANSLATION: u32 = 0;
const PROPERTY_ROTATION: u32 = 1;
const PROPERTY_SCALE: u32 = 2;
/// Per-joint record uploaded to the animation compute shader (storage binding 1).
///
/// Built by `SkinnedAnimationGpu::rebuild_static`, one entry per skin joint, with joints
/// that have fewer in-skin ancestors placed first. `#[repr(C)]` + `Pod` allow the slice to be
/// uploaded directly via `bytemuck::cast_slice`.
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuAnimBone {
/// Slot written for this joint: the skin's base bone index plus the joint's position in
/// `skin.joints`.
output_index: u32,
/// Output slot of the parent joint, or `-1` when the parent is not a joint of the same skin.
parent_output_index: i32,
/// First entry in the channel buffer belonging to this joint's current-clip channels.
cur_channel_start: u32,
/// Number of current-clip channels for this joint.
cur_channel_count: u32,
/// First entry in the channel buffer belonging to this joint's blend-from-clip channels.
blend_channel_start: u32,
/// Number of blend-from-clip channels for this joint.
blend_channel_count: u32,
// Explicit padding: the eight 4-byte fields span 32 bytes, so the 16-byte `rest_*` fields
// start on a 16-byte boundary (presumably matching the WGSL struct layout — confirm).
pad0: u32,
pad1: u32,
/// Rest-pose local translation; `w` is unused and set to 0.
rest_translation: [f32; 4],
/// Rest-pose local rotation quaternion stored as `x, y, z, w`.
rest_rotation: [f32; 4],
/// Rest-pose local scale; `w` is unused and set to 0.
rest_scale: [f32; 4],
}
/// Per-channel record uploaded to the animation compute shader (storage binding 2).
///
/// Produced by `FlatChannels::push_owned`; offsets index into the keyframe-time buffer
/// (binding 3) and keyframe-value buffer (binding 4).
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuAnimChannel {
/// One of the `PROPERTY_*` codes.
property: u32,
/// Interpolation code assigned in `build_owned_channel`: 0 = linear, 1 = step,
/// 2 = cubic spline.
interpolation: u32,
/// Index of this channel's first keyframe time in the times buffer.
input_offset: u32,
/// Number of keyframe times.
key_count: u32,
/// Index of this channel's first entry in the values buffer.
output_offset: u32,
/// Value entries per key: 1 for plain vec3/quat outputs, 3 (in-tangent, value,
/// out-tangent) for cubic-spline outputs.
output_stride: u32,
// Padding to a 32-byte record (presumably matching the WGSL struct layout — confirm).
pad0: u32,
pad1: u32,
}
/// Per-skeleton record written by `SkinnedAnimationGpu::upload_skeletons` on every update
/// (storage binding 0).
#[repr(C)]
#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
struct GpuAnimSkeleton {
/// Index of this skeleton's first `GpuAnimBone` in the bones buffer.
bone_start: u32,
/// Number of bones to evaluate; 0 when the driving player has no runtime entry this frame.
bone_count: u32,
/// The driving player's `time`.
time: f32,
/// The driving player's `blend_from_time`.
blend_from_time: f32,
/// The driving player's `blend_factor` (1.0 when the player has no runtime entry).
blend_factor: f32,
// Padding so `armature_root` begins at byte offset 32 (presumably matching the WGSL
// alignment of a mat4x4 — confirm).
pad0: u32,
pad1: u32,
pad2: u32,
/// Global transform of the first non-joint parent of the skin's joints (identity when none),
/// stored column-major as produced by `matrix_to_array`.
armature_root: [[f32; 4]; 4],
}
/// CPU-side copy of one animation channel, converted to GPU encodings but not yet flattened
/// into the shared buffers (see `FlatChannels::push_owned`).
struct OwnedChannel {
/// One of the `PROPERTY_*` codes.
property: u32,
/// Interpolation code (0 = linear, 1 = step, 2 = cubic spline).
interpolation: u32,
/// Keyframe times, copied from the sampler input.
input: Vec<f32>,
/// Keyframe outputs widened to `[f32; 4]`; `stride` entries per key.
values: Vec<[f32; 4]>,
/// Entries in `values` per key (1, or 3 for cubic-spline outputs).
stride: u32,
}
/// Per-frame playback state copied from an animation player in `SkinnedAnimationGpu::update`.
struct PlayerRuntime {
/// The player's `time`.
time: f32,
/// The player's `blend_from_time`.
blend_from_time: f32,
/// The player's `blend_factor`.
blend_factor: f32,
}
/// Location of one GPU-animated skin inside the bones buffer, captured at rebuild time.
struct SkeletonLayout {
/// Entity carrying the skin component.
skin_entity: Entity,
/// Animation player whose runtime state drives this skin.
player_entity: Entity,
/// Index of the skin's first `GpuAnimBone`.
bone_start: u32,
/// Number of joints in the skin.
bone_count: u32,
}
/// Flattened concatenation of every channel destined for the GPU.
///
/// Each [`GpuAnimChannel`] in `channels` refers to its keyframe times in `times` and its
/// outputs in `values` through the offsets it stores; the three vectors are uploaded as the
/// channel, time and value storage buffers.
#[derive(Default)]
struct FlatChannels {
    channels: Vec<GpuAnimChannel>,
    times: Vec<f32>,
    values: Vec<[f32; 4]>,
}

impl FlatChannels {
    /// Appends `channel`'s times and values and records a [`GpuAnimChannel`] describing
    /// the freshly appended ranges.
    fn push_owned(&mut self, channel: &OwnedChannel) {
        // The current lengths are exactly where the new data is about to land.
        let record = GpuAnimChannel {
            property: channel.property,
            interpolation: channel.interpolation,
            input_offset: self.times.len() as u32,
            key_count: channel.input.len() as u32,
            output_offset: self.values.len() as u32,
            output_stride: channel.stride,
            pad0: 0,
            pad1: 0,
        };
        self.times.extend_from_slice(&channel.input);
        self.values.extend_from_slice(&channel.values);
        self.channels.push(record);
    }
}
/// Converts an animation channel into a GPU-ready [`OwnedChannel`].
///
/// Outputs are widened to `[f32; 4]`: vec3 outputs get `w = 0.0`, quaternions are stored as
/// `x, y, z, w`. Cubic-spline outputs are interleaved per key as
/// `in_tangent, value, out_tangent`, giving a stride of 3 entries per key (1 otherwise).
///
/// Returns `None`, so the channel is simply not uploaded, when:
/// - the target property has no GPU code (see `gpu_property`),
/// - the output is a morph-weight sampler,
/// - the sampler has no keyframe times, or
/// - the output does not provide exactly one entry (one tangent/value/tangent triple for
///   cubic splines) per keyframe time.
fn build_owned_channel(channel: &AnimationChannel) -> Option<OwnedChannel> {
    let property = gpu_property(channel.target_property)?;
    let sampler = &channel.sampler;
    let interpolation = match sampler.interpolation {
        AnimationInterpolation::Linear => 0,
        AnimationInterpolation::Step => 1,
        AnimationInterpolation::CubicSpline => 2,
    };
    let key_count = sampler.input.len();
    if key_count == 0 {
        // Nothing to sample; skip rather than upload an empty key range to the shader.
        return None;
    }
    let (values, stride): (Vec<[f32; 4]>, u32) = match &sampler.output {
        AnimationSamplerOutput::Vec3(samples) => (
            samples
                .iter()
                .map(|value| [value.x, value.y, value.z, 0.0])
                .collect(),
            1,
        ),
        AnimationSamplerOutput::Quat(samples) => (
            samples
                .iter()
                .map(|value| {
                    let coords = value.coords;
                    [coords.x, coords.y, coords.z, coords.w]
                })
                .collect(),
            1,
        ),
        AnimationSamplerOutput::CubicSplineVec3 {
            values: samples,
            in_tangents,
            out_tangents,
        } => {
            // Malformed tangent arrays are rejected instead of panicking on indexing.
            if in_tangents.len() != samples.len() || out_tangents.len() != samples.len() {
                return None;
            }
            let mut values = Vec::with_capacity(samples.len() * 3);
            for ((in_tangent, value), out_tangent) in
                in_tangents.iter().zip(samples.iter()).zip(out_tangents.iter())
            {
                values.push([in_tangent.x, in_tangent.y, in_tangent.z, 0.0]);
                values.push([value.x, value.y, value.z, 0.0]);
                values.push([out_tangent.x, out_tangent.y, out_tangent.z, 0.0]);
            }
            (values, 3)
        }
        AnimationSamplerOutput::CubicSplineQuat {
            values: samples,
            in_tangents,
            out_tangents,
        } => {
            if in_tangents.len() != samples.len() || out_tangents.len() != samples.len() {
                return None;
            }
            let mut values = Vec::with_capacity(samples.len() * 3);
            for ((in_tangent, value), out_tangent) in
                in_tangents.iter().zip(samples.iter()).zip(out_tangents.iter())
            {
                let (in_tangent, value, out_tangent) =
                    (in_tangent.coords, value.coords, out_tangent.coords);
                values.push([in_tangent.x, in_tangent.y, in_tangent.z, in_tangent.w]);
                values.push([value.x, value.y, value.z, value.w]);
                values.push([out_tangent.x, out_tangent.y, out_tangent.z, out_tangent.w]);
            }
            (values, 3)
        }
        AnimationSamplerOutput::Weights(_) | AnimationSamplerOutput::CubicSplineWeights { .. } => {
            return None;
        }
    };
    // The GPU record advertises `key_count` keys of `stride` entries each; make sure that
    // many values were actually produced so the range never extends past this channel.
    if values.len() != key_count * stride as usize {
        return None;
    }
    Some(OwnedChannel {
        property,
        interpolation,
        input: sampler.input.clone(),
        values,
        stride,
    })
}
/// GPU evaluation of skeletal animation clips for skinned meshes.
///
/// Owns the animation compute pipeline and the storage buffers it reads. Joint results are
/// written into an externally owned bone-transform buffer bound at binding 5.
pub(super) struct SkinnedAnimationGpu {
/// Compute pipeline built from `animation_compute.wgsl`, entry point `main`.
pipeline: wgpu::ComputePipeline,
/// Layout with six storage-buffer bindings (0–4 read-only, 5 read-write).
bind_group_layout: wgpu::BindGroupLayout,
/// Bind group over the current buffers; `None` until `update` first rebuilds.
bind_group: Option<wgpu::BindGroup>,
/// `GpuAnimSkeleton` records, written on each `update` while skeletons exist (binding 0).
skeletons_buffer: wgpu::Buffer,
/// `GpuAnimBone` records for every animated joint (binding 1).
bones_buffer: wgpu::Buffer,
/// Flattened `GpuAnimChannel` records (binding 2).
channels_buffer: wgpu::Buffer,
/// Concatenated keyframe times (binding 3).
times_buffer: wgpu::Buffer,
/// Concatenated keyframe outputs, one `[f32; 4]` per entry (binding 4).
values_buffer: wgpu::Buffer,
/// Number of skeletons the skeleton buffer was last sized for.
// NOTE(review): written in `rebuild_static` but not read anywhere in this file view.
skeletons_capacity: usize,
/// One entry per GPU-animated skin, in the same order as the skeleton records.
layouts: Vec<SkeletonLayout>,
/// `static_signature` value the static buffers were last built for.
cached_static_signature: u64,
/// Bone-transform buffer generation the bind group was last built against.
bone_transforms_generation: u64,
}
/// Creates a storage buffer (also usable as a copy destination) initialised with `data`.
///
/// An empty `data` slice is replaced by 16 zero bytes so the buffer is never zero-sized
/// (presumably to keep it valid for a storage binding — confirm against wgpu validation).
fn storage_buffer(device: &wgpu::Device, label: &str, data: &[u8]) -> wgpu::Buffer {
    const PLACEHOLDER: [u8; 16] = [0; 16];
    let contents: &[u8] = match data {
        [] => &PLACEHOLDER,
        bytes => bytes,
    };
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: Some(label),
        contents,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
    };
    device.create_buffer_init(&descriptor)
}
impl SkinnedAnimationGpu {
    /// Creates the animation compute pipeline and placeholder storage buffers.
    ///
    /// The bind group layout declares six storage-buffer bindings for the compute stage:
    /// 0 skeletons, 1 bones, 2 channels, 3 keyframe times, 4 keyframe values (read-only) and
    /// 5 the bone-transform output buffer (read-write). No bind group exists until `update`
    /// first rebuilds the static data.
    pub(super) fn new(device: &wgpu::Device) -> Self {
        let shader = crate::render::wgpu::shader_compose::compile_wgsl(
            device,
            "animation_compute.wgsl",
            include_str!("../../../shaders/animation_compute.wgsl"),
        );
        let entries: Vec<wgpu::BindGroupLayoutEntry> = (0..6)
            .map(|binding| wgpu::BindGroupLayoutEntry {
                binding,
                visibility: wgpu::ShaderStages::COMPUTE,
                ty: wgpu::BindingType::Buffer {
                    ty: wgpu::BufferBindingType::Storage {
                        // Binding 5 (bone transforms) is the only buffer the shader writes.
                        read_only: binding != 5,
                    },
                    has_dynamic_offset: false,
                    min_binding_size: None,
                },
                count: None,
            })
            .collect();
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Animation Compute Bind Group Layout"),
            entries: &entries,
        });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Animation Compute Pipeline Layout"),
            bind_group_layouts: &[Some(&bind_group_layout)],
            immediate_size: 0,
        });
        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Animation Compute Pipeline"),
            layout: Some(&pipeline_layout),
            module: &shader,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });
        Self {
            pipeline,
            bind_group_layout,
            bind_group: None,
            skeletons_buffer: storage_buffer(device, "Anim Skeletons", &[]),
            bones_buffer: storage_buffer(device, "Anim Bones", &[]),
            channels_buffer: storage_buffer(device, "Anim Channels", &[]),
            times_buffer: storage_buffer(device, "Anim Times", &[]),
            values_buffer: storage_buffer(device, "Anim Values", &[]),
            skeletons_capacity: 0,
            layouts: Vec::new(),
            cached_static_signature: 0,
            // Sentinel so the first `update` sees a generation change and builds everything
            // (unless the caller's generation happens to be `u64::MAX`).
            bone_transforms_generation: u64::MAX,
        }
    }

    /// Number of skins currently evaluated on the GPU.
    pub(super) fn skeleton_count(&self) -> u32 {
        self.layouts.len() as u32
    }

    /// Records the animation dispatch into `compute_pass`.
    ///
    /// Does nothing when no skeleton is GPU-animated or no bind group has been built yet.
    /// Dispatches `ceil(skeletons / 64)` workgroups, presumably one invocation per skeleton
    /// (NOTE(review): confirm 64 matches `@workgroup_size` in `animation_compute.wgsl`).
    pub(super) fn dispatch(&self, compute_pass: &mut wgpu::ComputePass) {
        if self.layouts.is_empty() {
            return;
        }
        if let Some(bind_group) = self.bind_group.as_ref() {
            compute_pass.set_pipeline(&self.pipeline);
            compute_pass.set_bind_group(0, bind_group, &[]);
            compute_pass.dispatch_workgroups((self.layouts.len() as u32).div_ceil(64), 1, 1);
        }
    }

    /// Refreshes GPU animation state for the current frame.
    ///
    /// Considers every animation player that is `playing`, not `play_all`, and has a current
    /// clip. Each joint targeted by a translation/rotation/scale channel of such a clip is
    /// mapped to the first player (in query order) that targets it. When `static_signature`
    /// changes, or `bone_transforms_generation` differs from the last one seen, the static
    /// buffers are rebuilt and the bind group is recreated against `bone_transforms_buffer`.
    /// Per-skeleton playback state is uploaded on every call.
    ///
    /// * `skinning_cache` supplies the skinned entities and their base bone indices.
    /// * `bone_transforms_buffer` is the output buffer bound at binding 5.
    /// * `bone_transforms_generation` is the caller's generation counter for that buffer.
    pub(super) fn update(
        &mut self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        world: &World,
        skinning_cache: &SkinningCache,
        bone_transforms_buffer: &wgpu::Buffer,
        bone_transforms_generation: u64,
    ) {
        let mut joint_to_player: HashMap<Entity, Entity> = HashMap::new();
        let mut player_runtime: HashMap<Entity, PlayerRuntime> = HashMap::new();
        // (current_clip, blend_from_clip) per player, `-1` meaning "none"; only hashed.
        let mut player_clips: HashMap<Entity, (i64, i64)> = HashMap::new();
        world
            .core
            .query()
            .with(ANIMATION_PLAYER)
            .iter(|player_entity, table, idx| {
                let player = &table.animation_player[idx];
                if !player.playing || player.play_all {
                    return;
                }
                let Some(current) = player.get_current_clip() else {
                    return;
                };
                player_runtime.insert(
                    player_entity,
                    PlayerRuntime {
                        time: player.time,
                        blend_from_time: player.blend_from_time,
                        blend_factor: player.blend_factor,
                    },
                );
                player_clips.insert(
                    player_entity,
                    (
                        player.current_clip.map(|index| index as i64).unwrap_or(-1),
                        player
                            .blend_from_clip
                            .map(|index| index as i64)
                            .unwrap_or(-1),
                    ),
                );
                for channel in &current.channels {
                    if gpu_property(channel.target_property).is_some()
                        && let Some(target) = player.resolve_target_entity(channel)
                    {
                        // First player wins when several animate the same joint.
                        joint_to_player.entry(target).or_insert(player_entity);
                    }
                }
            });
        let signature = static_signature(world, skinning_cache, &joint_to_player, &player_clips);
        if signature != self.cached_static_signature
            || self.bone_transforms_generation != bone_transforms_generation
        {
            self.rebuild_static(device, world, skinning_cache, &joint_to_player);
            self.cached_static_signature = signature;
            self.bone_transforms_generation = bone_transforms_generation;
            // The rebuild replaced every owned buffer, so the bind group must be recreated.
            self.bind_group = Some(self.create_bind_group(device, bone_transforms_buffer));
        }
        self.upload_skeletons(queue, world, &player_runtime);
    }

    /// Rebuilds the bones, channels, keyframe-time and keyframe-value buffers and the
    /// skeleton layouts, and recreates a skeleton buffer sized for the new layouts (filled
    /// later by `upload_skeletons`).
    ///
    /// Skins are visited in `(id, generation)` order. A skin is included when at least one
    /// of its joints appears in `joint_to_player`; the first such joint's player drives it.
    /// Within a skin, joints are ordered by how many of their ancestors are joints of the
    /// same skin, so every in-skin parent precedes its children in the bones buffer.
    fn rebuild_static(
        &mut self,
        device: &wgpu::Device,
        world: &World,
        skinning_cache: &SkinningCache,
        joint_to_player: &HashMap<Entity, Entity>,
    ) {
        // Owned channel copies keyed by target joint: `joint_cur` from each player's current
        // clip, `joint_blend` from its blend-from clip.
        // NOTE(review): if several players target the same joint, channels from all of them
        // are attached to it although the skeleton samples them with a single player's time.
        let mut joint_cur: HashMap<Entity, Vec<OwnedChannel>> = HashMap::new();
        let mut joint_blend: HashMap<Entity, Vec<OwnedChannel>> = HashMap::new();
        world
            .core
            .query()
            .with(ANIMATION_PLAYER)
            .iter(|_, table, idx| {
                let player = &table.animation_player[idx];
                if !player.playing || player.play_all {
                    return;
                }
                // Same filter as `update`: a player without a current clip has no runtime
                // entry and drives no skeleton, so none of its channels (including those of
                // its blend-from clip) may end up on another player's bones.
                let Some(current) = player.get_current_clip() else {
                    return;
                };
                for channel in &current.channels {
                    if let Some(owned) = build_owned_channel(channel)
                        && let Some(target) = player.resolve_target_entity(channel)
                    {
                        joint_cur.entry(target).or_default().push(owned);
                    }
                }
                if let Some(from_clip) = player
                    .blend_from_clip
                    .and_then(|index| player.clips.get(index))
                {
                    for channel in &from_clip.channels {
                        if let Some(owned) = build_owned_channel(channel)
                            && let Some(target) = player.resolve_target_entity(channel)
                        {
                            joint_blend.entry(target).or_default().push(owned);
                        }
                    }
                }
            });
        let mut bones: Vec<GpuAnimBone> = Vec::new();
        let mut flat = FlatChannels::default();
        let mut layouts: Vec<SkeletonLayout> = Vec::new();
        let mut skin_entities: Vec<Entity> =
            skinning_cache.entity_skin_indices.keys().copied().collect();
        skin_entities.sort_by_key(|entity| (entity.id, entity.generation));
        for skin_entity in skin_entities {
            let Some(skin) = world.core.get_skin(skin_entity) else {
                continue;
            };
            let Some(&player_entity) = skin
                .joints
                .iter()
                .find_map(|joint| joint_to_player.get(joint))
            else {
                continue;
            };
            let skin_index = skinning_cache.entity_skin_indices[&skin_entity];
            let base_bone_index = skinning_cache.get_base_bone_index(skin_index);
            let local_index_of: HashMap<Entity, usize> = skin
                .joints
                .iter()
                .enumerate()
                .map(|(local, joint)| (*joint, local))
                .collect();
            // `joint_depth` walks the parent chain, so compute it once per joint; the sort is
            // stable, giving the same order `sort_by_key` would.
            let mut order: Vec<usize> = (0..skin.joints.len()).collect();
            order.sort_by_cached_key(|&local| {
                joint_depth(world, &skin.joints, &local_index_of, local)
            });
            let bone_start = bones.len() as u32;
            for &local in &order {
                let joint = skin.joints[local];
                let rest = world
                    .core
                    .get_local_transform(joint)
                    .copied()
                    .unwrap_or_default();
                let parent_local = world
                    .core
                    .get_parent(joint)
                    .and_then(|parent| parent.0)
                    .and_then(|parent| local_index_of.get(&parent).copied());
                let parent_output_index = match parent_local {
                    Some(parent) => (base_bone_index + parent as u32) as i32,
                    None => -1,
                };
                let (cur_start, cur_count) =
                    append_joint_channels(&mut flat, joint_cur.get(&joint));
                let (blend_start, blend_count) =
                    append_joint_channels(&mut flat, joint_blend.get(&joint));
                bones.push(GpuAnimBone {
                    output_index: base_bone_index + local as u32,
                    parent_output_index,
                    cur_channel_start: cur_start,
                    cur_channel_count: cur_count,
                    blend_channel_start: blend_start,
                    blend_channel_count: blend_count,
                    pad0: 0,
                    pad1: 0,
                    rest_translation: [
                        rest.translation.x,
                        rest.translation.y,
                        rest.translation.z,
                        0.0,
                    ],
                    rest_rotation: [
                        rest.rotation.coords.x,
                        rest.rotation.coords.y,
                        rest.rotation.coords.z,
                        rest.rotation.coords.w,
                    ],
                    rest_scale: [rest.scale.x, rest.scale.y, rest.scale.z, 0.0],
                });
            }
            layouts.push(SkeletonLayout {
                skin_entity,
                player_entity,
                bone_start,
                bone_count: skin.joints.len() as u32,
            });
        }
        self.layouts = layouts;
        self.bones_buffer = storage_buffer(device, "Anim Bones", bytemuck::cast_slice(&bones));
        self.channels_buffer = storage_buffer(
            device,
            "Anim Channels",
            bytemuck::cast_slice(&flat.channels),
        );
        self.times_buffer = storage_buffer(device, "Anim Times", bytemuck::cast_slice(&flat.times));
        self.values_buffer =
            storage_buffer(device, "Anim Values", bytemuck::cast_slice(&flat.values));
        // At least one record's worth of bytes so the buffer is never zero-sized.
        let skeleton_bytes = self.layouts.len().max(1) * std::mem::size_of::<GpuAnimSkeleton>();
        self.skeletons_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Anim Skeletons"),
            size: skeleton_bytes as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });
        self.skeletons_capacity = self.layouts.len();
    }

    /// Writes one `GpuAnimSkeleton` per layout into the skeleton buffer.
    ///
    /// Skeletons whose player has no entry in `player_runtime` are uploaded with
    /// `bone_count = 0`, time 0 and blend factor 1.0. Does nothing when no layouts exist.
    fn upload_skeletons(
        &mut self,
        queue: &wgpu::Queue,
        world: &World,
        player_runtime: &HashMap<Entity, PlayerRuntime>,
    ) {
        if self.layouts.is_empty() {
            return;
        }
        let mut skeletons: Vec<GpuAnimSkeleton> = Vec::with_capacity(self.layouts.len());
        for layout in &self.layouts {
            let (bone_count, time, blend_from_time, blend_factor) =
                match player_runtime.get(&layout.player_entity) {
                    Some(runtime) => (
                        layout.bone_count,
                        runtime.time,
                        runtime.blend_from_time,
                        runtime.blend_factor,
                    ),
                    None => (0, 0.0, 0.0, 1.0),
                };
            let armature_root = skin_armature_root(world, layout.skin_entity);
            skeletons.push(GpuAnimSkeleton {
                bone_start: layout.bone_start,
                bone_count,
                time,
                blend_from_time,
                blend_factor,
                pad0: 0,
                pad1: 0,
                pad2: 0,
                armature_root: matrix_to_array(armature_root),
            });
        }
        queue.write_buffer(&self.skeletons_buffer, 0, bytemuck::cast_slice(&skeletons));
    }

    /// Builds the bind group over the owned buffers (bindings 0–4) plus
    /// `bone_transforms_buffer` (binding 5).
    fn create_bind_group(
        &self,
        device: &wgpu::Device,
        bone_transforms_buffer: &wgpu::Buffer,
    ) -> wgpu::BindGroup {
        device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Animation Compute Bind Group"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: self.skeletons_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: self.bones_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: self.channels_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: self.times_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 4,
                    resource: self.values_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 5,
                    resource: bone_transforms_buffer.as_entire_binding(),
                },
            ],
        })
    }
}
/// Maps an animation property onto the shader's `PROPERTY_*` code.
///
/// Returns `None` for properties the compute path does not evaluate (morph weights).
fn gpu_property(property: AnimationProperty) -> Option<u32> {
    let code = match property {
        AnimationProperty::Translation => PROPERTY_TRANSLATION,
        AnimationProperty::Rotation => PROPERTY_ROTATION,
        AnimationProperty::Scale => PROPERTY_SCALE,
        AnimationProperty::MorphWeights => return None,
    };
    Some(code)
}
/// Appends a joint's channels (if any) to `flat`.
///
/// Returns `(start, count)`: the index of the first appended channel record and how many
/// were appended. `start` is the current channel count even when nothing is appended.
fn append_joint_channels(flat: &mut FlatChannels, owned: Option<&Vec<OwnedChannel>>) -> (u32, u32) {
    let first = flat.channels.len() as u32;
    let appended = owned.map_or(0, |list| {
        list.iter().for_each(|channel| flat.push_owned(channel));
        list.len() as u32
    });
    (first, appended)
}
/// Counts how many consecutive ancestors of `joints[local]` are themselves joints of the skin.
///
/// The walk stops at the first ancestor missing from `local_index_of`, at a node without a
/// parent, or once the count exceeds `joints.len()` (a guard against parent cycles).
fn joint_depth(
    world: &World,
    joints: &[Entity],
    local_index_of: &HashMap<Entity, usize>,
    local: usize,
) -> u32 {
    let limit = joints.len() as u32;
    let mut node = joints[local];
    let mut depth = 0u32;
    loop {
        let Some(parent) = world.core.get_parent(node).and_then(|parent| parent.0) else {
            break;
        };
        if !local_index_of.contains_key(&parent) {
            break;
        }
        depth += 1;
        if depth > limit {
            break;
        }
        node = parent;
    }
    depth
}
/// Returns the global transform of the skin's armature root.
///
/// Scans the skin's joints in order and uses the first joint whose parent exists, is not a
/// joint of the same skin, and has a global transform. Falls back to identity when the skin
/// is missing or no such parent is found.
fn skin_armature_root(world: &World, skin_entity: Entity) -> nalgebra_glm::Mat4 {
    let skin = match world.core.get_skin(skin_entity) {
        Some(skin) => skin,
        None => return nalgebra_glm::Mat4::identity(),
    };
    let in_skin: HashMap<Entity, usize> = skin.joints.iter().copied().zip(0..).collect();
    for joint in &skin.joints {
        let Some(parent) = world.core.get_parent(*joint).and_then(|parent| parent.0) else {
            continue;
        };
        if in_skin.contains_key(&parent) {
            continue;
        }
        if let Some(global) = world.core.get_global_transform(parent) {
            return global.0;
        }
    }
    nalgebra_glm::Mat4::identity()
}
/// Converts a matrix into a column-major `[column][row]` array for GPU upload.
fn matrix_to_array(matrix: nalgebra_glm::Mat4) -> [[f32; 4]; 4] {
    std::array::from_fn(|column| std::array::from_fn(|row| matrix[(row, column)]))
}
/// Hashes everything `SkinnedAnimationGpu::rebuild_static` bakes into the static GPU
/// buffers, so that `update` only rebuilds when this value changes.
///
/// Skins from `skinning_cache` are visited in `(id, generation)` order, the same order the
/// rebuild uses. Skins missing from the world, or with no joint in `joint_to_player`, are
/// skipped. For each remaining skin this hashes:
/// - the skin entity,
/// - its joint entities, prefixed by the joint count so that adjacent skins' joint lists
///   cannot run together,
/// - the player entity that drives the skin, and
/// - that player's `(current_clip, blend_from_clip)` indices from `player_clips`
///   (`-1` meaning "none").
///
/// Whole `Entity` values are hashed rather than only `.id`, so a recycled id is detected
/// provided `Entity`'s `Hash` covers its generation (NOTE(review): confirm). The driving
/// player is included because `SkeletonLayout::player_entity` and the uploaded channels are
/// captured at rebuild time. Without it, moving a skin to another player with the same clip
/// indices would leave the layout pointing at the old player.
fn static_signature(
    world: &World,
    skinning_cache: &SkinningCache,
    joint_to_player: &HashMap<Entity, Entity>,
    player_clips: &HashMap<Entity, (i64, i64)>,
) -> u64 {
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    let mut skin_entities: Vec<Entity> =
        skinning_cache.entity_skin_indices.keys().copied().collect();
    skin_entities.sort_by_key(|entity| (entity.id, entity.generation));
    for skin_entity in skin_entities {
        let Some(skin) = world.core.get_skin(skin_entity) else {
            continue;
        };
        let Some(&player_entity) = skin
            .joints
            .iter()
            .find_map(|joint| joint_to_player.get(joint))
        else {
            continue;
        };
        skin_entity.hash(&mut hasher);
        skin.joints.len().hash(&mut hasher);
        for joint in &skin.joints {
            joint.hash(&mut hasher);
        }
        player_entity.hash(&mut hasher);
        if let Some(clips) = player_clips.get(&player_entity) {
            clips.hash(&mut hasher);
        }
    }
    hasher.finish()
}