1use std::collections::HashMap;
16
17use itertools::Itertools;
18use rend3::{
19 types::{
20 glam::{Mat4, Quat, Vec3},
21 SkeletonHandle,
22 },
23 util::typedefs::{FastHashMap, FastHashSet},
24 Renderer,
25};
26use rend3_gltf::{AnimationChannel, GltfSceneInstance, LoadedGltfScene};
27
/// Position of an animation within the loaded scene's `animations` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct AnimationIndex(pub usize);

/// Position of a skin within the loaded scene's `skins` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SkinIndex(pub usize);

/// Index of a node in the scene instance's `nodes` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeIndex(pub usize);

/// Position of a joint within a single skin's joint list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct JointIndex(pub usize);
36
37pub struct PerSkinData {
40 pub node_to_joint_idx: FastHashMap<NodeIndex, JointIndex>,
44 pub joint_nodes_topological_order: Vec<NodeIndex>,
48 pub skeletons: Vec<SkeletonHandle>,
52}
53
54pub struct AnimationData {
56 pub skin_data: FastHashMap<SkinIndex, PerSkinData>,
59 pub animation_skin_usage: FastHashMap<AnimationIndex, Vec<SkinIndex>>,
63}
64
65impl AnimationData {
66 pub fn from_gltf_scene(scene: &LoadedGltfScene, instance: &GltfSceneInstance) -> Self {
79 let animation_to_joint_nodes: HashMap<AnimationIndex, FastHashSet<NodeIndex>> = scene
83 .animations
84 .iter()
85 .enumerate()
86 .flat_map(|(anim_idx, anim)| {
87 anim.inner
88 .channels
89 .keys()
90 .map(move |node_idx| (AnimationIndex(anim_idx), NodeIndex(*node_idx)))
91 })
92 .into_grouping_map()
93 .collect::<FastHashSet<_>>();
94
95 let mut animation_skin_usage = FastHashMap::<AnimationIndex, Vec<SkinIndex>>::default();
96 for animation_idx in 0..scene.animations.len() {
97 let animation_idx = AnimationIndex(animation_idx);
98 for (skin_index, skin) in scene.skins.iter().enumerate() {
99 let skin_index = SkinIndex(skin_index);
100
101 let anim_affected_nodes = &animation_to_joint_nodes[&animation_idx];
102 if skin
103 .inner
104 .joints
105 .iter()
106 .any(|j| anim_affected_nodes.contains(&NodeIndex(j.inner.node_idx)))
107 {
108 let entry = animation_skin_usage
109 .entry(animation_idx)
110 .or_insert_with(Default::default);
111 entry.push(skin_index);
112 }
113 }
114 }
115
116 let mut skin_data = FastHashMap::default();
117 for (skin_index, skin) in scene.skins.iter().enumerate() {
118 let skin_index = SkinIndex(skin_index);
119
120 let node_to_joint_idx = skin
121 .inner
122 .joints
123 .iter()
124 .enumerate()
125 .map(|(idx, joint)| (NodeIndex(joint.inner.node_idx), JointIndex(idx)))
126 .collect();
127
128 let skin_nodes: Vec<NodeIndex> = skin.inner.joints.iter().map(|j| NodeIndex(j.inner.node_idx)).collect();
130
131 let joint_nodes_topological_order: Vec<NodeIndex> = instance
132 .topological_order
133 .iter()
134 .map(|node_idx| NodeIndex(*node_idx))
135 .filter(|node_idx| skin_nodes.contains(node_idx))
136 .collect();
137
138 let skeletons: Vec<SkeletonHandle> = instance
139 .nodes
140 .iter()
141 .flat_map(|node| &node.inner.object)
142 .flat_map(|object| &object.inner.armature)
143 .filter(|armature| armature.skin_index == skin_index.0)
144 .flat_map(|armature| &armature.skeletons)
145 .cloned()
146 .collect();
147
148 skin_data.insert(
149 skin_index,
150 PerSkinData {
151 node_to_joint_idx,
152 joint_nodes_topological_order,
153 skeletons,
154 },
155 );
156 }
157
158 AnimationData {
159 skin_data,
160 animation_skin_usage,
161 }
162 }
163}
164
/// Linear interpolation between two values of the same type.
pub trait Lerp {
    /// Blends from `self` towards `other` by factor `t`.
    ///
    /// Callers in this module pass `t` in `[0, 1]`, where `0` means `self` and `1` means
    /// `other`.
    fn lerp(self, other: Self, t: f32) -> Self;
}
169impl Lerp for Vec3 {
170 fn lerp(self, other: Self, t: f32) -> Self {
171 self.lerp(other, t)
172 }
173}
174impl Lerp for Quat {
175 fn lerp(self, other: Self, t: f32) -> Self {
176 self.lerp(other, t).normalize()
180 }
181}
182
183fn sample_at_time<T: Lerp + Copy>(channel: &AnimationChannel<T>, current_time: f32) -> T {
186 let next_idx = channel
187 .times
188 .iter()
189 .position(|time| *time > current_time)
190 .unwrap_or(channel.times.len() - 1);
191 let prev_idx = next_idx.saturating_sub(1);
192
193 let interp_factor = f32::clamp(
194 (current_time - channel.times[prev_idx]) / (channel.times[next_idx] - channel.times[prev_idx]),
195 0.0,
196 1.0,
197 );
198
199 channel.values[prev_idx].lerp(channel.values[next_idx], interp_factor)
200}
201
202pub fn pose_animation_frame(
206 renderer: &Renderer,
207 scene: &LoadedGltfScene,
208 instance: &GltfSceneInstance,
209 animation_data: &AnimationData,
210 animation_index: usize,
211 time: f32,
212) {
213 let animation = &scene.animations[animation_index];
214 let time = time.clamp(0.0, animation.inner.duration);
215
216 for (skin_index, per_skin_data) in &animation_data.skin_data {
217 let skin = &scene.skins[skin_index.0];
218 let inv_bind_mats = &skin.inner.inverse_bind_matrices;
219
220 let mut joint_local_matrices = vec![Mat4::IDENTITY; inv_bind_mats.len()];
222
223 let node_to_joint_idx = &per_skin_data.node_to_joint_idx;
224
225 for (&node_idx, channels) in &animation.inner.channels {
227 let local_transform = instance.nodes[node_idx].inner.local_transform;
230 let (bind_scale, bind_rotation, bind_translation) = local_transform.to_scale_rotation_translation();
231
232 let translation = channels
233 .translation
234 .as_ref()
235 .map(|tra| sample_at_time(tra, time))
236 .unwrap_or(bind_translation);
237 let rotation = channels
238 .rotation
239 .as_ref()
240 .map(|rot| sample_at_time(rot, time))
241 .unwrap_or(bind_rotation);
242 let scale = channels
243 .scale
244 .as_ref()
245 .map(|sca| sample_at_time(sca, time))
246 .unwrap_or(bind_scale);
247
248 let matrix = Mat4::from_scale_rotation_translation(scale, rotation, translation);
249 let joint_idx = node_to_joint_idx[&NodeIndex(node_idx)];
250 joint_local_matrices[joint_idx.0] = matrix;
251 }
252
253 let mut global_joint_transforms = vec![Mat4::IDENTITY; inv_bind_mats.len()];
254
255 for node_idx in &per_skin_data.joint_nodes_topological_order {
257 let node = &instance.nodes[node_idx.0].inner;
258 let joint_idx = node_to_joint_idx[node_idx];
259 if let Some(parent_joint_idx) = node.parent.map(|pi| node_to_joint_idx.get(&NodeIndex(pi))) {
260 let parent_transform = parent_joint_idx
263 .map(|p| global_joint_transforms[p.0])
264 .unwrap_or(Mat4::IDENTITY);
265 let current_transform = joint_local_matrices[joint_idx.0];
266
267 global_joint_transforms[joint_idx.0] = parent_transform * current_transform;
268 } else {
269 global_joint_transforms[joint_idx.0] = joint_local_matrices[joint_idx.0];
270 }
271 }
272
273 for skeleton in &per_skin_data.skeletons {
275 renderer.set_skeleton_joint_transforms(skeleton, &global_joint_transforms, inv_bind_mats);
276 }
277 }
278}